diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..240447c073b0e898646f6ccb5f9945efebb6f9ef 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..1625c1793607996fcfc46420e8aa2f3d2b7efd1e
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,121 @@
+Creative Commons Legal Code
+
+CC0 1.0 Universal
+
+ CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
+ LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
+ ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
+ INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
+ REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
+ PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
+ THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
+ HEREUNDER.
+
+Statement of Purpose
+
+The laws of most jurisdictions throughout the world automatically confer
+exclusive Copyright and Related Rights (defined below) upon the creator
+and subsequent owner(s) (each and all, an "owner") of an original work of
+authorship and/or a database (each, a "Work").
+
+Certain owners wish to permanently relinquish those rights to a Work for
+the purpose of contributing to a commons of creative, cultural and
+scientific works ("Commons") that the public can reliably and without fear
+of later claims of infringement build upon, modify, incorporate in other
+works, reuse and redistribute as freely as possible in any form whatsoever
+and for any purposes, including without limitation commercial purposes.
+These owners may contribute to the Commons to promote the ideal of a free
+culture and the further production of creative, cultural and scientific
+works, or to gain reputation or greater distribution for their Work in
+part through the use and efforts of others.
+
+For these and/or other purposes and motivations, and without any
+expectation of additional consideration or compensation, the person
+associating CC0 with a Work (the "Affirmer"), to the extent that he or she
+is an owner of Copyright and Related Rights in the Work, voluntarily
+elects to apply CC0 to the Work and publicly distribute the Work under its
+terms, with knowledge of his or her Copyright and Related Rights in the
+Work and the meaning and intended legal effect of CC0 on those rights.
+
+1. Copyright and Related Rights. A Work made available under CC0 may be
+protected by copyright and related or neighboring rights ("Copyright and
+Related Rights"). Copyright and Related Rights include, but are not
+limited to, the following:
+
+ i. the right to reproduce, adapt, distribute, perform, display,
+ communicate, and translate a Work;
+ ii. moral rights retained by the original author(s) and/or performer(s);
+iii. publicity and privacy rights pertaining to a person's image or
+ likeness depicted in a Work;
+ iv. rights protecting against unfair competition in regards to a Work,
+ subject to the limitations in paragraph 4(a), below;
+ v. rights protecting the extraction, dissemination, use and reuse of data
+ in a Work;
+ vi. database rights (such as those arising under Directive 96/9/EC of the
+ European Parliament and of the Council of 11 March 1996 on the legal
+ protection of databases, and under any national implementation
+ thereof, including any amended or successor version of such
+ directive); and
+vii. other similar, equivalent or corresponding rights throughout the
+ world based on applicable law or treaty, and any national
+ implementations thereof.
+
+2. Waiver. To the greatest extent permitted by, but not in contravention
+of, applicable law, Affirmer hereby overtly, fully, permanently,
+irrevocably and unconditionally waives, abandons, and surrenders all of
+Affirmer's Copyright and Related Rights and associated claims and causes
+of action, whether now known or unknown (including existing as well as
+future claims and causes of action), in the Work (i) in all territories
+worldwide, (ii) for the maximum duration provided by applicable law or
+treaty (including future time extensions), (iii) in any current or future
+medium and for any number of copies, and (iv) for any purpose whatsoever,
+including without limitation commercial, advertising or promotional
+purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
+member of the public at large and to the detriment of Affirmer's heirs and
+successors, fully intending that such Waiver shall not be subject to
+revocation, rescission, cancellation, termination, or any other legal or
+equitable action to disrupt the quiet enjoyment of the Work by the public
+as contemplated by Affirmer's express Statement of Purpose.
+
+3. Public License Fallback. Should any part of the Waiver for any reason
+be judged legally invalid or ineffective under applicable law, then the
+Waiver shall be preserved to the maximum extent permitted taking into
+account Affirmer's express Statement of Purpose. In addition, to the
+extent the Waiver is so judged Affirmer hereby grants to each affected
+person a royalty-free, non transferable, non sublicensable, non exclusive,
+irrevocable and unconditional license to exercise Affirmer's Copyright and
+Related Rights in the Work (i) in all territories worldwide, (ii) for the
+maximum duration provided by applicable law or treaty (including future
+time extensions), (iii) in any current or future medium and for any number
+of copies, and (iv) for any purpose whatsoever, including without
+limitation commercial, advertising or promotional purposes (the
+"License"). The License shall be deemed effective as of the date CC0 was
+applied by Affirmer to the Work. Should any part of the License for any
+reason be judged legally invalid or ineffective under applicable law, such
+partial invalidity or ineffectiveness shall not invalidate the remainder
+of the License, and in such case Affirmer hereby affirms that he or she
+will not (i) exercise any of his or her remaining Copyright and Related
+Rights in the Work or (ii) assert any associated claims and causes of
+action with respect to the Work, in either case contrary to Affirmer's
+express Statement of Purpose.
+
+4. Limitations and Disclaimers.
+
+ a. No trademark or patent rights held by Affirmer are waived, abandoned,
+ surrendered, licensed or otherwise affected by this document.
+ b. Affirmer offers the Work as-is and makes no representations or
+ warranties of any kind concerning the Work, express, implied,
+ statutory or otherwise, including without limitation warranties of
+ title, merchantability, fitness for a particular purpose, non
+ infringement, or the absence of latent or other defects, accuracy, or
+ the present or absence of errors, whether or not discoverable, all to
+ the greatest extent permissible under applicable law.
+ c. Affirmer disclaims responsibility for clearing rights of other persons
+ that may apply to the Work or any use thereof, including without
+ limitation any person's Copyright and Related Rights in the Work.
+ Further, Affirmer disclaims responsibility for obtaining any necessary
+ consents, permissions or other rights required for any use of the
+ Work.
+ d. Affirmer understands and acknowledges that Creative Commons is not a
+ party to this document and has no duty or obligation with respect to
+ this CC0 or use of the Work.
\ No newline at end of file
diff --git a/README.md b/README.md
index 3e9a1fca9d33c9ee98d2011884ef83fb2df905ce..7d99b1296f01a87707ed1ab57c6a70e20e6634f6 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,319 @@
----
-license: cc0-1.0
----
+# Metric3D Project
+
+**Official PyTorch implementation of Metric3Dv1 and Metric3Dv2:**
+
+[1] [Metric3D: Towards Zero-shot Metric 3D Prediction from A Single Image](https://arxiv.org/abs/2307.10984)
+
+[2] Metric3Dv2: A Versatile Monocular Geometric Foundation Model for Zero-shot Metric Depth and Surface Normal Estimation
+
+
+
+
+
+
+[//]: # (### [Project Page](https://arxiv.org/abs/2307.08695) | [v2 Paper](https://arxiv.org/abs/2307.10984) | [v1 Arxiv](https://arxiv.org/abs/2307.10984) | [Video](https://www.youtube.com/playlist?list=PLEuyXJsWqUNd04nwfm9gFBw5FVbcaQPl3) | [Hugging Face](https://huggingface.co/spaces/JUGGHM/Metric3D) )
+
+## News and To-Do List
+
+- [ ] DroidSLAM code
+- [ ] Release the ViT-giant2 model
+- [ ] Focal-length-free mode
+- [ ] Floating-noise removal mode
+- [ ] Improve the Hugging Face demo and visualization
+- [x] Release training codes
+
+- `[2024/3/18]` HuggingFace GPU version updated!
+- `[2024/3/18]` [Project page](https://jugghm.github.io/Metric3Dv2/) released!
+- `[2024/3/18]` Metric3D V2 models released, supporting metric depth and surface normal now!
+- `[2023/8/10]` Inference codes, pretrained weights, and demo released.
+- `[2023/7]` Metric3D accepted by ICCV 2023!
+- `[2023/4]` Champion of the [2nd Monocular Depth Estimation Challenge](https://jspenmar.github.io/MDEC) at CVPR 2023!
+
+## Abstract
+Metric3D is a versatile geometric foundation model for high-quality, zero-shot **metric depth** and **surface normal** estimation from a single image. It excels at in-the-wild scene reconstruction.
+
+
+
+
+
+## Benchmarks
+
+### Metric Depth
+
+[//]: # (#### Zero-shot Testing)
+
+[//]: # (Our models work well on both indoor and outdoor scenarios, compared with other zero-shot metric depth estimation methods.)
+
+[//]: # ()
+[//]: # (| | Backbone | KITTI $\delta 1$ ↑ | KITTI $\delta 2$ ↑ | KITTI $\delta 3$ ↑ | KITTI AbsRel ↓ | KITTI RMSE ↓ | KITTI RMS_log ↓ | NYU $\delta 1$ ↑ | NYU $\delta 2$ ↑ | NYU $\delta 3$ ↑ | NYU AbsRel ↓ | NYU RMSE ↓ | NYU log10 ↓ |)
+
+[//]: # (|-----------------|------------|--------------------|---------------------|--------------------|-----------------|---------------|------------------|------------------|------------------|------------------|---------------|-------------|--------------|)
+
+[//]: # (| ZeroDepth | ResNet-18 | 0.910 | 0.980 | 0.996 | 0.057 | 4.044 | 0.083 | 0.901 | 0.961 | - | 0.100 | 0.380 | - |)
+
+[//]: # (| PolyMax | ConvNeXt-L | - | - | - | - | - | - | 0.969 | 0.996 | 0.999 | 0.067 | 0.250 | 0.033 |)
+
+[//]: # (| Ours | ViT-L | 0.985 | 0.995 | 0.999 | 0.052 | 2.511 | 0.074 | 0.975 | 0.994 | 0.998 | 0.063 | 0.251 | 0.028 |)
+
+[//]: # (| Ours | ViT-g2 | 0.989 | 0.996 | 0.999 | 0.051 | 2.403 | 0.080 | 0.980 | 0.997 | 0.999 | 0.067 | 0.260 | 0.030 |)
+
+[//]: # ()
+[//]: # ([//]: # (| Adabins | Efficient-B5 | 0.964 | 0.995 | 0.999 | 0.058 | 2.360 | 0.088 | 0.903 | 0.984 | 0.997 | 0.103 | 0.0444 | 0.364 |))
+[//]: # ([//]: # (| NewCRFs | SwinT-L | 0.974 | 0.997 | 0.999 | 0.052 | 2.129 | 0.079 | 0.922 | 0.983 | 0.994 | 0.095 | 0.041 | 0.334 |))
+[//]: # ([//]: # (| Ours (CSTM_label) | ConvNeXt-L | 0.964 | 0.993 | 0.998 | 0.058 | 2.770 | 0.092 | 0.944 | 0.986 | 0.995 | 0.083 | 0.035 | 0.310 |))
+
+[//]: # (#### Finetuned)
+Our models rank 1st on the KITTI and NYU benchmarks.
+
+| | Backbone | KITTI δ1 ↑ | KITTI δ2 ↑ | KITTI AbsRel ↓ | KITTI RMSE ↓ | KITTI RMS_log ↓ | NYU δ1 ↑ | NYU δ2 ↑ | NYU AbsRel ↓ | NYU RMSE ↓ | NYU log10 ↓ |
+|---------------|-------------|------------|-------------|-----------------|---------------|------------------|----------|----------|---------------|-------------|--------------|
+| ZoeDepth | ViT-Large | 0.971 | 0.995 | 0.053 | 2.281 | 0.082 | 0.953 | 0.995 | 0.077 | 0.277 | 0.033 |
+| ZeroDepth | ResNet-18 | 0.968 | 0.996 | 0.057 | 2.087 | 0.083 | 0.954 | 0.995 | 0.074 | 0.269 | 0.103 |
+| IEBins | SwinT-Large | 0.978 | 0.998 | 0.050 | 2.011 | 0.075 | 0.936 | 0.992 | 0.087 | 0.314 | 0.031 |
+| DepthAnything | ViT-Large | 0.982 | 0.998 | 0.046 | 1.985 | 0.069 | 0.984 | 0.998 | 0.056 | 0.206 | 0.024 |
+| Ours | ViT-Large | 0.985 | 0.998 | 0.999 | 1.985 | 0.064 | 0.989 | 0.998 | 0.047 | 0.183 | 0.020 |
+| Ours | ViT-giant2 | 0.989 | 0.998 | 1.000 | 1.766 | 0.060 | 0.987 | 0.997 | 0.045 | 0.187 | 0.015 |
+
+### Affine-invariant Depth
+Even compared to recent affine-invariant depth methods (Marigold and Depth Anything), our metric-depth (and normal) models still show superior performance.
+
+| | #Data for Pretrain and Train | KITTI AbsRel ↓ | KITTI δ1 ↑ | NYUv2 AbsRel ↓ | NYUv2 δ1 ↑ | DIODE-Full AbsRel ↓ | DIODE-Full δ1 ↑ | ETH3D AbsRel ↓ | ETH3D δ1 ↑ |
+|-----------------------|----------------------------------------------|----------------|------------|-----------------|------------|---------------------|-----------------|----------------------|------------|
+| OmniData (v2, ViT-L) | 1.3M + 12.2M | 0.069 | 0.948 | 0.074 | 0.945 | 0.149 | 0.835 | 0.166 | 0.778 |
+| Marigold (LDMv2) | 5B + 74K | 0.099 | 0.916 | 0.055 | 0.961 | 0.308 | 0.773 | 0.127 | 0.960 |
+| DepthAnything (ViT-L) | 142M + 63M | 0.076 | 0.947 | 0.043 | 0.981 | 0.277 | 0.759 | 0.065 | 0.882 |
+| Ours (ViT-L) | 142M + 16M | 0.042 | 0.979 | 0.042 | 0.980 | 0.141 | 0.882 | 0.042 | 0.987 |
+| Ours (ViT-g) | 142M + 16M | 0.043 | 0.982 | 0.043 | 0.981 | 0.136 | 0.895 | 0.042 | 0.983 |
+
+
+### Surface Normal
+Our models also show strong performance on surface normal benchmarks.
+
+| | NYU 11.25° ↑ | NYU Mean ↓ | NYU RMS ↓ | ScanNet 11.25° ↑ | ScanNet Mean ↓ | ScanNet RMS ↓ | iBims 11.25° ↑ | iBims Mean ↓ | iBims RMS ↓ |
+|--------------|----------|----------|-----------|-----------------|----------------|--------------|---------------|--------------|-------------|
+| EESNU | 0.597 | 16.0 | 24.7 | 0.711 | 11.8 | 20.3 | 0.585 | 20.0 | - |
+| IronDepth | - | - | - | - | - | - | 0.431 | 25.3 | 37.4 |
+| PolyMax | 0.656 | 13.1 | 20.4 | - | - | - | - | - | - |
+| Ours (ViT-L) | 0.688 | 12.0 | 19.2 | 0.760 | 9.9 | 16.4 | 0.694 | 19.4 | 34.9 |
+| Ours (ViT-g) | 0.662 | 13.2 | 20.2 | 0.778 | 9.2 | 15.3 | 0.697 | 19.6 | 35.2 |
+
+
+
+## Demos
+
+### Zero-shot monocular metric depth & surface normal
+
+
+
+### Zero-shot metric 3D recovery
+
+
+### Improving monocular SLAM
+
+
+[//]: # (https://github.com/YvanYin/Metric3D/assets/35299633/f95815ef-2506-4193-a6d9-1163ea821268)
+
+[//]: # (https://github.com/YvanYin/Metric3D/assets/35299633/ed00706c-41cc-49ea-accb-ad0532633cc2)
+
+[//]: # (### Zero-shot metric 3D recovery)
+
+[//]: # (https://github.com/YvanYin/Metric3D/assets/35299633/26cd7ae1-dd5a-4446-b275-54c5ca7ef945)
+
+[//]: # (https://github.com/YvanYin/Metric3D/assets/35299633/21e5484b-c304-4fe3-b1d3-8eebc4e26e42)
+[//]: # (### Monocular reconstruction for a Sequence)
+
+[//]: # ()
+[//]: # (### In-the-wild 3D reconstruction)
+
+[//]: # ()
+[//]: # (| | Image | Reconstruction | Pointcloud File |)
+
+[//]: # (|:---------:|:------------------:|:------------------:|:--------:|)
+
+[//]: # (| room |  |  | [Download](https://drive.google.com/file/d/1P1izSegH2c4LUrXGiUksw037PVb0hjZr/view?usp=drive_link) |)
+
+[//]: # (| Colosseum |  |  | [Download](https://drive.google.com/file/d/1jJCXe5IpxBhHDr0TZtNZhjxKTRUz56Hg/view?usp=drive_link) |)
+
+[//]: # (| chess |  |  | [Download](https://drive.google.com/file/d/1oV_Foq25_p-tTDRTcyO2AzXEdFJQz-Wm/view?usp=drive_link) |)
+
+[//]: # ()
+[//]: # (All three images are downloaded from [unsplash](https://unsplash.com/) and put in the data/wild_demo directory.)
+
+[//]: # ()
+[//]: # (### 3D metric reconstruction, Metric3D × DroidSLAM)
+
+[//]: # (Metric3D can also provide scale information for DroidSLAM, helping to solve the scale-drift problem and produce better trajectories. )
+
+[//]: # ()
+[//]: # (#### Bird's Eye View (Left: Droid-SLAM (mono). Right: Droid-SLAM with Metric3D))
+
+[//]: # ()
+
+[//]: # ()
+[//]: # (### Front View)
+
+[//]: # ()
+[//]: # ()
+
+
+[//]: # ()
+[//]: # (#### KITTI odometry evaluation (Translational RMS drift (t_rel, ↓) / Rotational RMS drift (r_rel, ↓)))
+
+[//]: # (| | Modality | seq 00 | seq 02 | seq 05 | seq 06 | seq 08 | seq 09 | seq 10 |)
+
+[//]: # (|:----------:|:--------:|:----------:|:----------:|:---------:|:----------:|:----------:|:---------:|:---------:|)
+
+[//]: # (| ORB-SLAM2 | Mono | 11.43/0.58 | 10.34/0.26 | 9.04/0.26 | 14.56/0.26 | 11.46/0.28 | 9.3/0.26 | 2.57/0.32 |)
+
+[//]: # (| Droid-SLAM | Mono | 33.9/0.29 | 34.88/0.27 | 23.4/0.27 | 17.2/0.26 | 39.6/0.31 | 21.7/0.23 | 7/0.25 |)
+
+[//]: # (| Droid+Ours | Mono | 1.44/0.37 | 2.64/0.29 | 1.44/0.25 | 0.6/0.2 | 2.2/0.3 | 1.63/0.22 | 2.73/0.23 |)
+
+[//]: # (| ORB-SLAM2 | Stereo | 0.88/0.31 | 0.77/0.28 | 0.62/0.26 | 0.89/0.27 | 1.03/0.31 | 0.86/0.25 | 0.62/0.29 |)
+
+[//]: # ()
+[//]: # (Metric3D makes the mono-SLAM scale-aware, like stereo systems.)
+
+[//]: # ()
+[//]: # (#### KITTI sequence videos - Youtube)
+
+[//]: # ([2011_09_30_drive_0028](https://youtu.be/gcTB4MgVCLQ) /)
+
+[//]: # ([2011_09_30_drive_0033](https://youtu.be/He581fmoPP4) /)
+
+[//]: # ([2011_09_30_drive_0034](https://youtu.be/I3PkukQ3_F8))
+
+[//]: # ()
+[//]: # (#### Estimated pose)
+
+[//]: # ([2011_09_30_drive_0033](https://drive.google.com/file/d/1SMXWzLYrEdmBe6uYMR9ShtDXeFDewChv/view?usp=drive_link) / )
+
+[//]: # ([2011_09_30_drive_0034](https://drive.google.com/file/d/1ONU4GxpvTlgW0TjReF1R2i-WFxbbjQPG/view?usp=drive_link) /)
+
+[//]: # ([2011_10_03_drive_0042](https://drive.google.com/file/d/19fweg6p1Q6TjJD2KlD7EMA_aV4FIeQUD/view?usp=drive_link))
+
+[//]: # ()
+[//]: # (#### Pointcloud files)
+
+[//]: # ([2011_09_30_drive_0033](https://drive.google.com/file/d/1K0o8DpUmLf-f_rue0OX1VaHlldpHBAfw/view?usp=drive_link) /)
+
+[//]: # ([2011_09_30_drive_0034](https://drive.google.com/file/d/1bvZ6JwMRyvi07H7Z2VD_0NX1Im8qraZo/view?usp=drive_link) /)
+
+[//]: # ([2011_10_03_drive_0042](https://drive.google.com/file/d/1Vw59F8nN5ApWdLeGKXvYgyS9SNKHKy4x/view?usp=drive_link))
+
+## Installation
+### One-line Installation
+For the ViT models, use the following environment:
+```bash
+pip install -r requirements_v2.txt
+```
+
+For ConvNeXt-L, use:
+```bash
+pip install -r requirements_v1.txt
+```
+
+### Dataset annotation components
+For off-the-shelf depth datasets, you need to generate JSON annotations compatible with this codebase, organized as follows:
+```
+dict(
+    'files': list(
+        dict(
+            'rgb': 'data/kitti_demo/rgb/xxx.png',
+            'depth': 'data/kitti_demo/depth/xxx.png',
+            'depth_scale': 1000.0,  # the scale factor of the stored gt depth
+            'cam_in': [fx, fy, cx, cy],
+        ),
+
+        dict(
+            ...
+        ),
+
+        ...
+    )
+)
+```
+To generate such annotations, please refer to the "Inference" section.
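+
+A minimal generator sketch (it mirrors ```data/gene_annos_kitti_demo.py``` shipped in this repo; the KITTI demo paths and intrinsics below are only examples, so adjust them to your dataset):
+```python
+import json
+import os
+import os.path as osp
+
+data_root = 'data/kitti_demo'                      # dataset to annotate
+cam_in = [707.0493, 707.0493, 604.0814, 180.5066]  # [fx, fy, cx, cy]
+
+files = []
+for name in sorted(os.listdir(osp.join(data_root, 'rgb'))):
+    files.append(dict(
+        rgb=osp.join(data_root, 'rgb', name),
+        depth=osp.join(data_root, 'depth', name),  # optional gt depth
+        depth_scale=256.0,                         # scale factor of the stored gt depth
+        cam_in=cam_in,
+    ))
+
+with open(osp.join(data_root, 'test_annotations.json'), 'w') as f:
+    json.dump(dict(files=files), f)
+```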
+
+### Configs
+In ```mono/configs``` we provide different config setups.
+
+The intrinsics of the canonical camera are set below:
+```
+ canonical_space = dict(
+ img_size=(512, 960),
+ focal_length=1000.0,
+ ),
+```
+where cx and cy are set to half of the image size.
+
+Inference settings are defined as
+```
+ depth_range=(0, 1),
+ depth_normalize=(0.3, 150),
+ crop_size = (512, 1088),
+```
+where images are first resized to the ```crop_size``` and then fed into the model.
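+
+For intuition, below is a rough sketch of this preprocessing and of the de-canonical transform (an illustration only, not the repository's API; the function names are ours, the canonical focal length of 1000.0 is the value set above, and the shipped test scripts in ```mono/utils/do_test.py``` handle all of this for you):
+```python
+import cv2
+
+def keep_ratio_resize(rgb, intrinsics, crop_size=(512, 1088)):
+    """Resize the image to fit crop_size and scale [fx, fy, cx, cy] accordingly."""
+    h, w = rgb.shape[:2]
+    scale = min(crop_size[0] / h, crop_size[1] / w)
+    rgb = cv2.resize(rgb, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)
+    intrinsics = [v * scale for v in intrinsics]
+    return rgb, intrinsics
+
+def de_canonical(pred_depth, fx, canonical_focal=1000.0):
+    # The network predicts depth in the canonical camera space; scaling by
+    # fx / canonical_focal converts it back to metric depth for the real camera.
+    return pred_depth * fx / canonical_focal
+```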
+
+## Inference
+### Download Checkpoint
+| | Encoder | Decoder | Link |
+|:----:|:-------------------:|:-----------------:|:-------------------------------------------------------------------------------------------------:|
+| v1-T | ConvNeXt-Tiny | Hourglass-Decoder | Coming soon |
+| v1-L | ConvNeXt-Large | Hourglass-Decoder | [Download](https://drive.google.com/file/d/1KVINiBkVpJylx_6z1lAC7CQ4kmn-RJRN/view?usp=drive_link) |
+| v2-S | DINO2reg-ViT-Small | RAFT-4iter | [Download](https://drive.google.com/file/d/1YfmvXwpWmhLg3jSxnhT7LvY0yawlXcr_/view?usp=drive_link) |
+| v2-L | DINO2reg-ViT-Large | RAFT-8iter | [Download](https://drive.google.com/file/d/1eT2gG-kwsVzNy5nJrbm4KC-9DbNKyLnr/view?usp=drive_link) |
+| v2-g | DINO2reg-ViT-giant2 | RAFT-8iter | Coming soon |
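+
+After downloading, a quick sanity check that the file loads (a sketch; the exact key layout of the checkpoint may differ, so inspect the top-level keys first):
+```python
+import torch
+
+ckpt = torch.load('weight/model.pth', map_location='cpu')
+print(type(ckpt), list(ckpt.keys())[:5] if isinstance(ckpt, dict) else None)
+# If the weights are wrapped (e.g. under a 'model_state_dict'-style key, hypothetical), unwrap them:
+state = ckpt.get('model_state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
+n_params = sum(v.numel() for v in state.values() if torch.is_tensor(v))
+print(f'loaded ~{n_params / 1e6:.1f}M parameters')
+```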
+
+### Dataset Mode
+1. Put the trained checkpoint ```model.pth``` in ```weight/```.
+2. Generate the data annotation by following ```data/gene_annos_kitti_demo.py```, which includes 'rgb', (optional) 'intrinsic', (optional) 'depth', and (optional) 'depth_scale'.
+3. Change 'test_data_path' in ```test_*.sh``` to the path of the generated ```*.json``` file.
+4. Run ```source test_kitti.sh``` or ```source test_nyu.sh```.
+
+### In-the-Wild Mode
+1. Put the trained checkpoint ```model.pth``` in ```weight/```.
+2. Change 'test_data_path' in ```test.sh``` to the image folder path.
+3. Run ```source test_vit.sh``` for transformers and ```source test.sh``` for convnets.
+As no intrinsics are provided, we use 9 default focal-length settings.
+
+## Q & A
+### Q1: Why do the depth maps look good but the point clouds are distorted?
+Because the focal length is not set properly! Please find a proper focal length by modifying the code [here](mono/utils/do_test.py#309) yourself.
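+
+For reference, metric depth unprojects to a point cloud with the standard pinhole model, so a wrong fx/fy stretches the cloud along the viewing direction. A minimal sketch (illustrative only, not the repository code):
+```python
+import numpy as np
+
+def unproject(depth, fx, fy, cx, cy):
+    """Back-project an (H, W) metric depth map into an (H*W, 3) point cloud."""
+    h, w = depth.shape
+    u, v = np.meshgrid(np.arange(w), np.arange(h))
+    x = (u - cx) / fx * depth
+    y = (v - cy) / fy * depth
+    return np.stack([x, y, depth], axis=-1).reshape(-1, 3)
+
+# For an uncalibrated photo, a rough fx (in pixels) can be estimated from EXIF data:
+# fx ≈ focal_length_mm / sensor_width_mm * image_width_px
+```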
+
+### Q2: Why are the point clouds so slow to generate?
+Because the images are too large! Use smaller ones instead.
+
+### Q3: Why are the predicted depth maps not satisfactory?
+First, make sure all black padding regions at the image boundaries are cropped out, then try again.
+Besides, Metric3D is not almighty: some objects (chandeliers, drones, ...) and camera views (aerial view, BEV, ...) occur rarely in the training datasets. We will dig deeper into these cases and release more powerful solutions.
+
+## Citation
+```
+@article{hu2024metric3dv2,
+ title={Metric3Dv2: A Versatile Monocular Geometric Foundation Model for Zero-shot Metric Depth and Surface Normal Estimation},
+ author={Hu, Mu and Yin, Wei and Zhang, Chi and Cai, Zhipeng and Long, Xiaoxiao and Chen, Hao and Wang, Kaixuan and Yu, Gang and Shen, Chunhua and Shen, Shaojie},
+ journal={arXiv preprint},
+ year={2024}
+}
+```
+```
+@inproceedings{yin2023metric,
+ title={Metric3D: Towards Zero-shot Metric 3D Prediction from A Single Image},
+ author={Yin, Wei and Zhang, Chi and Chen, Hao and Cai, Zhipeng and Yu, Gang and Wang, Kaixuan and Chen, Xiaozhi and Shen, Chunhua},
+ booktitle={ICCV},
+ year={2023}
+}
+```
+
+## License and Contact
+
+The *Metric3D* code is under a 2-clause BSD license for non-commercial use. For further questions, contact Dr. Yvan Yin [yvanwy@outlook.com] or Mr. Mu Hu [mhuam@connect.ust.hk].
diff --git a/data/gene_annos_kitti_demo.py b/data/gene_annos_kitti_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..9edf1b4910c28593056dc03020d02674a70b522b
--- /dev/null
+++ b/data/gene_annos_kitti_demo.py
@@ -0,0 +1,32 @@
+if __name__=='__main__':
+ import os
+ import os.path as osp
+ import numpy as np
+ import cv2
+ import json
+
+ code_root = '/mnt/nas/share/home/xugk/MetricDepth_test/'
+
+ data_root = osp.join(code_root, 'data/kitti_demo')
+ split_root = code_root
+
+ files = []
+ rgb_root = osp.join(data_root, 'rgb')
+ depth_root = osp.join(data_root, 'depth')
+ for rgb_file in os.listdir(rgb_root):
+ rgb_path = osp.join(rgb_root, rgb_file).split(split_root)[-1]
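+        # strip the code_root prefix so the stored paths are relative to the repo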
+ depth_path = rgb_path.replace('/rgb/', '/depth/')
+ cam_in = [707.0493, 707.0493, 604.0814, 180.5066]
+ depth_scale = 256.
+
+ meta_data = {}
+ meta_data['cam_in'] = cam_in
+ meta_data['rgb'] = rgb_path
+ meta_data['depth'] = depth_path
+ meta_data['depth_scale'] = depth_scale
+ files.append(meta_data)
+ files_dict = dict(files=files)
+
+ with open(osp.join(code_root, 'data/kitti_demo/test_annotations.json'), 'w') as f:
+ json.dump(files_dict, f)
+
\ No newline at end of file
diff --git a/data/gene_annos_nyu_demo.py b/data/gene_annos_nyu_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c1a4e815ed882d4d513c407d3c0b718d5ec6d7
--- /dev/null
+++ b/data/gene_annos_nyu_demo.py
@@ -0,0 +1,31 @@
+if __name__=='__main__':
+ import os
+ import os.path as osp
+ import numpy as np
+ import cv2
+ import json
+
+ code_root = '/mnt/nas/share/home/xugk/MetricDepth_test/'
+
+ data_root = osp.join(code_root, 'data/nyu_demo')
+ split_root = code_root
+
+ files = []
+ rgb_root = osp.join(data_root, 'rgb')
+ depth_root = osp.join(data_root, 'depth')
+ for rgb_file in os.listdir(rgb_root):
+ rgb_path = osp.join(rgb_root, rgb_file).split(split_root)[-1]
+ depth_path = rgb_path.replace('.jpg', '.png').replace('/rgb_', '/sync_depth_').replace('/rgb/', '/depth/')
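+        # e.g. data/nyu_demo/rgb/rgb_00000.jpg -> data/nyu_demo/depth/sync_depth_00000.png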
+ cam_in = [518.8579, 519.46961, 325.58245, 253.73617]
+ depth_scale = 1000.
+
+ meta_data = {}
+ meta_data['cam_in'] = cam_in
+ meta_data['rgb'] = rgb_path
+ meta_data['depth'] = depth_path
+ meta_data['depth_scale'] = depth_scale
+ files.append(meta_data)
+ files_dict = dict(files=files)
+
+ with open(osp.join(code_root, 'data/nyu_demo/test_annotations.json'), 'w') as f:
+ json.dump(files_dict, f)
\ No newline at end of file
diff --git a/data/kitti_demo/depth/0000000005.png b/data/kitti_demo/depth/0000000005.png
new file mode 100644
index 0000000000000000000000000000000000000000..37c81e35db5b9d57680b26d1e3dfb14fcea68be3
--- /dev/null
+++ b/data/kitti_demo/depth/0000000005.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb0d83fc93bcf235384c690ae405e0b24b3bfc6a05e1220a4c902bed3b5ba113
+size 191967
diff --git a/data/kitti_demo/depth/0000000050.png b/data/kitti_demo/depth/0000000050.png
new file mode 100644
index 0000000000000000000000000000000000000000..395eba26aeb29fc7729df1e221adaaee183696a2
--- /dev/null
+++ b/data/kitti_demo/depth/0000000050.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3eef554b3b312829e7d1e76a1acd13e7261024eb3c4d6e176328be377ff9216e
+size 200646
diff --git a/data/kitti_demo/depth/0000000100.png b/data/kitti_demo/depth/0000000100.png
new file mode 100644
index 0000000000000000000000000000000000000000..4c06323540561465cd66bf91d871789fdc8291c7
--- /dev/null
+++ b/data/kitti_demo/depth/0000000100.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b7e9c85e2b4f8131019fe93e0c1cf36f5058b30d040998a8199c4bb2d97e9b1
+size 181743
diff --git a/data/kitti_demo/rgb/0000000005.png b/data/kitti_demo/rgb/0000000005.png
new file mode 100644
index 0000000000000000000000000000000000000000..89592167d59e65cb87478890a1d870e7e23fcdc7
--- /dev/null
+++ b/data/kitti_demo/rgb/0000000005.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9754dcadc8b3ace31a368500af3e382e2c0763242a7b054d424650cec67646a
+size 872928
diff --git a/data/kitti_demo/rgb/0000000050.png b/data/kitti_demo/rgb/0000000050.png
new file mode 100644
index 0000000000000000000000000000000000000000..19b8fc027b2de753c089b728c6162d25b3e59e0e
--- /dev/null
+++ b/data/kitti_demo/rgb/0000000050.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19e4f8f377521c8e28aca9addf2b695f9e374e5f44ee38d58970d12a21fbc4bf
+size 873924
diff --git a/data/kitti_demo/rgb/0000000100.png b/data/kitti_demo/rgb/0000000100.png
new file mode 100644
index 0000000000000000000000000000000000000000..475f4e4be43091dbc3669f8c5b2a22ecd6c961e9
--- /dev/null
+++ b/data/kitti_demo/rgb/0000000100.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f216c6fa51fb640c6cfb8a16cc91f60b20b1d2775def3d86c52c2bba1388365
+size 916166
diff --git a/data/kitti_demo/test_annotations.json b/data/kitti_demo/test_annotations.json
new file mode 100644
index 0000000000000000000000000000000000000000..0153ec662a98b6c921eddbdb87132013d69111c9
--- /dev/null
+++ b/data/kitti_demo/test_annotations.json
@@ -0,0 +1 @@
+{"files": [{"cam_in": [707.0493, 707.0493, 604.0814, 180.5066], "rgb": "data/kitti_demo/rgb/0000000050.png", "depth": "data/kitti_demo/depth/0000000050.png", "depth_scale": 256.0}, {"cam_in": [707.0493, 707.0493, 604.0814, 180.5066], "rgb": "data/kitti_demo/rgb/0000000100.png", "depth": "data/kitti_demo/depth/0000000100.png", "depth_scale": 256.0}, {"cam_in": [707.0493, 707.0493, 604.0814, 180.5066], "rgb": "data/kitti_demo/rgb/0000000005.png", "depth": "data/kitti_demo/depth/0000000005.png", "depth_scale": 256.0}]}
\ No newline at end of file
diff --git a/data/nyu_demo/depth/sync_depth_00000.png b/data/nyu_demo/depth/sync_depth_00000.png
new file mode 100644
index 0000000000000000000000000000000000000000..d1157d8388bfed3bbb0b1b2eb3e05e124d87f9c7
--- /dev/null
+++ b/data/nyu_demo/depth/sync_depth_00000.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:043e9c8bee7af97afff01e451da3f5e9cd1591995f415944dd0dc91036a35b5a
+size 166196
diff --git a/data/nyu_demo/depth/sync_depth_00050.png b/data/nyu_demo/depth/sync_depth_00050.png
new file mode 100644
index 0000000000000000000000000000000000000000..2b7d8857e55727b382f64f9958325b1762e067aa
--- /dev/null
+++ b/data/nyu_demo/depth/sync_depth_00050.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53c764e869f61cf4240586395bc7374dcc02e65b8442801b53b74ffa563d30fe
+size 182376
diff --git a/data/nyu_demo/depth/sync_depth_00100.png b/data/nyu_demo/depth/sync_depth_00100.png
new file mode 100644
index 0000000000000000000000000000000000000000..7b7e7e77298ad11bbd6c157f145aca06637f07e6
--- /dev/null
+++ b/data/nyu_demo/depth/sync_depth_00100.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dc0c16d56bfdcc958f37fa28bcf39b110a14c317bfe3c221b3c3bc6d73dec67d
+size 141576
diff --git a/data/nyu_demo/rgb/rgb_00000.jpg b/data/nyu_demo/rgb/rgb_00000.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..af64697bccaf1d017d105e8eb407a1ff95b1ae4e
Binary files /dev/null and b/data/nyu_demo/rgb/rgb_00000.jpg differ
diff --git a/data/nyu_demo/rgb/rgb_00050.jpg b/data/nyu_demo/rgb/rgb_00050.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0712c42a51187f5e1d34e74f4d98f1489ba9251f
Binary files /dev/null and b/data/nyu_demo/rgb/rgb_00050.jpg differ
diff --git a/data/nyu_demo/rgb/rgb_00100.jpg b/data/nyu_demo/rgb/rgb_00100.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f5388677b4c9d2ad5e083b0753ae14193c7aaf48
Binary files /dev/null and b/data/nyu_demo/rgb/rgb_00100.jpg differ
diff --git a/data/nyu_demo/test_annotations.json b/data/nyu_demo/test_annotations.json
new file mode 100644
index 0000000000000000000000000000000000000000..806fe4635d30b0f810ae5df568365e76551dc7c9
--- /dev/null
+++ b/data/nyu_demo/test_annotations.json
@@ -0,0 +1 @@
+{"files": [{"cam_in": [518.8579, 519.46961, 325.58245, 253.73617], "rgb": "data/nyu_demo/rgb/rgb_00000.jpg", "depth": "data/nyu_demo/depth/sync_depth_00000.png", "depth_scale": 1000.0}, {"cam_in": [518.8579, 519.46961, 325.58245, 253.73617], "rgb": "data/nyu_demo/rgb/rgb_00050.jpg", "depth": "data/nyu_demo/depth/sync_depth_00050.png", "depth_scale": 1000.0}, {"cam_in": [518.8579, 519.46961, 325.58245, 253.73617], "rgb": "data/nyu_demo/rgb/rgb_00100.jpg", "depth": "data/nyu_demo/depth/sync_depth_00100.png", "depth_scale": 1000.0}]}
\ No newline at end of file
diff --git a/data/wild_demo/david-kohler-VFRTXGw1VjU-unsplash.jpg b/data/wild_demo/david-kohler-VFRTXGw1VjU-unsplash.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..4703e2931365f128562080f1857b6efeb07fc380
Binary files /dev/null and b/data/wild_demo/david-kohler-VFRTXGw1VjU-unsplash.jpg differ
diff --git a/data/wild_demo/jonathan-borba-CnthDZXCdoY-unsplash.jpg b/data/wild_demo/jonathan-borba-CnthDZXCdoY-unsplash.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c7905c09e7334493622faec304dc04f338aee898
Binary files /dev/null and b/data/wild_demo/jonathan-borba-CnthDZXCdoY-unsplash.jpg differ
diff --git a/data/wild_demo/randy-fath-G1yhU1Ej-9A-unsplash.jpg b/data/wild_demo/randy-fath-G1yhU1Ej-9A-unsplash.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c39e76927455d8e972874c2af6781e3c836dc313
Binary files /dev/null and b/data/wild_demo/randy-fath-G1yhU1Ej-9A-unsplash.jpg differ
diff --git a/data_info/__init__.py b/data_info/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec8374be5bc1a77bc72386ebf46cb50154217684
--- /dev/null
+++ b/data_info/__init__.py
@@ -0,0 +1,2 @@
+from .public_datasets import *
+from .pretrained_weight import *
\ No newline at end of file
diff --git a/data_info/pretrained_weight.py b/data_info/pretrained_weight.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e7226f02c486942e715f8f8b7d0287809bd451c
--- /dev/null
+++ b/data_info/pretrained_weight.py
@@ -0,0 +1,16 @@
+
+mldb_info={}
+
+mldb_info['checkpoint']={
+ 'mldb_root': '/mnt/nas/share/home/xugk/ckpt', # NOTE: modify it to the pretrained ckpt root
+
+ # pretrained weight for convnext
+ 'convnext_tiny': 'convnext/convnext_tiny_22k_1k_384.pth',
+ 'convnext_small': 'convnext/convnext_small_22k_1k_384.pth',
+ 'convnext_base': 'convnext/convnext_base_22k_1k_384.pth',
+ 'convnext_large': 'convnext/convnext_large_22k_1k_384.pth',
+ 'vit_large': 'vit/dinov2_vitl14_pretrain.pth',
+ 'vit_small_reg': 'vit/dinov2_vits14_reg4_pretrain.pth',
+ 'vit_large_reg': 'vit/dinov2_vitl14_reg4_pretrain.pth',
+ 'vit_giant2_reg': 'vit/dinov2_vitg14_reg4_pretrain.pth',
+}
\ No newline at end of file
diff --git a/data_info/public_datasets.py b/data_info/public_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6b67618fff864bfc42cbf5609e17ad2e041b05d
--- /dev/null
+++ b/data_info/public_datasets.py
@@ -0,0 +1,7 @@
+mldb_info = {}
+
+mldb_info['NYU']={
+ 'mldb_root': '/mnt/nas/share/home/xugk/data/',
+ 'data_root': 'nyu',
+ 'test_annotations_path': 'nyu/test_annotation.json',
+}
diff --git a/media/gifs/demo_1.gif b/media/gifs/demo_1.gif
new file mode 100644
index 0000000000000000000000000000000000000000..778f5293416d638c6918206e8b450a6ec5f1ec2c
--- /dev/null
+++ b/media/gifs/demo_1.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f07ee050ca8b76991966f45bb74eae6e61e6b11eeb9466b524c6ab5164711d36
+size 10693260
diff --git a/media/gifs/demo_12.gif b/media/gifs/demo_12.gif
new file mode 100644
index 0000000000000000000000000000000000000000..dbed9296e787d0c91844b83c2e51a5357395b1d9
--- /dev/null
+++ b/media/gifs/demo_12.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1886d6dff7714d015e6b7c004d88d3014057b1cefe1dc5544fa9bedb81383bc
+size 9414756
diff --git a/media/gifs/demo_2.gif b/media/gifs/demo_2.gif
new file mode 100644
index 0000000000000000000000000000000000000000..f1e0c7476553812836ef197804776c47062985d6
--- /dev/null
+++ b/media/gifs/demo_2.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d11e3f9a11374166fc363a3fed17928957de546a548eccc4c7efa4d9317cf4c5
+size 9023151
diff --git a/media/gifs/demo_22.gif b/media/gifs/demo_22.gif
new file mode 100644
index 0000000000000000000000000000000000000000..6a093b3ec4deac3034bfd229e9bcfdbb0240cd25
--- /dev/null
+++ b/media/gifs/demo_22.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c56b0785a5991126d02b349f8801980f31b2ef7b661cad07be4888ff42dc29d0
+size 6390996
diff --git a/media/screenshots/challenge.PNG b/media/screenshots/challenge.PNG
new file mode 100644
index 0000000000000000000000000000000000000000..ccff81751639620f7e9c8ab4aabca53b4bb5e7b2
Binary files /dev/null and b/media/screenshots/challenge.PNG differ
diff --git a/media/screenshots/page2.png b/media/screenshots/page2.png
new file mode 100644
index 0000000000000000000000000000000000000000..8d77ac08d1488d1e1f39d232247b15f35fc26ff3
--- /dev/null
+++ b/media/screenshots/page2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c46a332e0f9f868c767f65f70c0fa11ec4f7da2dfe69d47046dff5c37964c171
+size 4347474
diff --git a/media/screenshots/pipeline.png b/media/screenshots/pipeline.png
new file mode 100644
index 0000000000000000000000000000000000000000..ec566b347f5628fa4bd53cdd38d5549902f68eee
--- /dev/null
+++ b/media/screenshots/pipeline.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19a7b36e83761aae0ecd27e1215e31fded8c9ef3d308734e690456921703f662
+size 398892
diff --git a/mono/configs/HourglassDecoder/convlarge.0.3_150.py b/mono/configs/HourglassDecoder/convlarge.0.3_150.py
new file mode 100644
index 0000000000000000000000000000000000000000..37b91c80284d6db3df3017ec636f18198e42dc08
--- /dev/null
+++ b/mono/configs/HourglassDecoder/convlarge.0.3_150.py
@@ -0,0 +1,25 @@
+_base_=[
+ '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/default_runtime.py',
+ ]
+
+model = dict(
+ backbone=dict(
+ pretrained=False,
+ )
+)
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(512, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.3, 150),
+ crop_size = (544, 1216),
+)
+
+batchsize_per_gpu = 2
+thread_per_gpu = 4
diff --git a/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py b/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdd9156b7f2f0921fb01b1adaf9a2a7447332d6e
--- /dev/null
+++ b/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py
@@ -0,0 +1,25 @@
+_base_=[
+ '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/default_runtime.py',
+ ]
+
+model = dict(
+ backbone=dict(
+ pretrained=False,
+ )
+)
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(512, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.3, 150),
+ crop_size = (512, 1088),
+)
+
+batchsize_per_gpu = 2
+thread_per_gpu = 4
diff --git a/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py b/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py
new file mode 100644
index 0000000000000000000000000000000000000000..6601f5cdfad07c5fad8b89fbf959e67039126dfa
--- /dev/null
+++ b/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py
@@ -0,0 +1,25 @@
+_base_=[
+ '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/default_runtime.py',
+ ]
+
+model = dict(
+ backbone=dict(
+ pretrained=False,
+ )
+)
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(512, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.3, 150),
+ crop_size = (480, 1216),
+)
+
+batchsize_per_gpu = 2
+thread_per_gpu = 4
diff --git a/mono/configs/HourglassDecoder/vit.raft5.large.py b/mono/configs/HourglassDecoder/vit.raft5.large.py
new file mode 100644
index 0000000000000000000000000000000000000000..4febdcb2867513008496f394ce8dc513230fb480
--- /dev/null
+++ b/mono/configs/HourglassDecoder/vit.raft5.large.py
@@ -0,0 +1,33 @@
+_base_=[
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/default_runtime.py',
+ ]
+
+import numpy as np
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+
+max_value = 200
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, max_value),
+    crop_size = (616, 1064), # both dimensions are multiples of 28
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064)
+)
+
+batchsize_per_gpu = 1
+thread_per_gpu = 1
diff --git a/mono/configs/HourglassDecoder/vit.raft5.small.py b/mono/configs/HourglassDecoder/vit.raft5.small.py
new file mode 100644
index 0000000000000000000000000000000000000000..25eb68cc151f090c7654b7ebbcaf9dfc6a478570
--- /dev/null
+++ b/mono/configs/HourglassDecoder/vit.raft5.small.py
@@ -0,0 +1,33 @@
+_base_=[
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/default_runtime.py',
+ ]
+
+import numpy as np
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+
+max_value = 200
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, max_value),
+    crop_size = (616, 1064), # both dimensions are multiples of 28
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064)
+)
+
+batchsize_per_gpu = 1
+thread_per_gpu = 1
diff --git a/mono/configs/__init__.py b/mono/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/mono/configs/__init__.py
@@ -0,0 +1 @@
+
diff --git a/mono/configs/_base_/_data_base_.py b/mono/configs/_base_/_data_base_.py
new file mode 100644
index 0000000000000000000000000000000000000000..35f3844f24191b6b9452e136ea3205b7622466d7
--- /dev/null
+++ b/mono/configs/_base_/_data_base_.py
@@ -0,0 +1,13 @@
+# canonical camera setting and basic data setting
+# we set it the same as the E300 camera (crop version)
+#
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(540, 960),
+ focal_length=1196.0,
+ ),
+ depth_range=(0.9, 150),
+ depth_normalize=(0.006, 1.001),
+ crop_size = (512, 960),
+ clip_depth_range=(0.9, 150),
+)
diff --git a/mono/configs/_base_/datasets/_data_base_.py b/mono/configs/_base_/datasets/_data_base_.py
new file mode 100644
index 0000000000000000000000000000000000000000..b554444e9b75b4519b862e726890dcf7859be0ec
--- /dev/null
+++ b/mono/configs/_base_/datasets/_data_base_.py
@@ -0,0 +1,12 @@
+# canonical camera setting and basic data setting
+#
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(540, 960),
+ focal_length=1196.0,
+ ),
+ depth_range=(0.9, 150),
+ depth_normalize=(0.006, 1.001),
+ crop_size = (512, 960),
+ clip_depth_range=(0.9, 150),
+)
diff --git a/mono/configs/_base_/default_runtime.py b/mono/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..a690b491bf50aad5c2fd7e9ac387609123a4594a
--- /dev/null
+++ b/mono/configs/_base_/default_runtime.py
@@ -0,0 +1,4 @@
+
+load_from = None
+cudnn_benchmark = True
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3','rmse_log', 'log10', 'sq_rel']
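+# standard depth metrics: abs_rel = mean(|pred - gt| / gt), delta_i = fraction of pixels with max(pred/gt, gt/pred) < 1.25^i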
diff --git a/mono/configs/_base_/models/backbones/convnext_large.py b/mono/configs/_base_/models/backbones/convnext_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a22f7e1b53ca154bfae1672e6ee3b52028039b9
--- /dev/null
+++ b/mono/configs/_base_/models/backbones/convnext_large.py
@@ -0,0 +1,16 @@
+#_base_ = ['./_model_base_.py',]
+
+#'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth'
+model = dict(
+ #type='EncoderDecoderAuxi',
+ backbone=dict(
+ type='convnext_large',
+ pretrained=True,
+ in_22k=True,
+ out_indices=[0, 1, 2, 3],
+ drop_path_rate=0.4,
+ layer_scale_init_value=1.0,
+ checkpoint='data/pretrained_weight_repo/convnext/convnext_large_22k_1k_384.pth',
+ prefix='backbones.',
+ out_channels=[192, 384, 768, 1536]),
+ )
diff --git a/mono/configs/_base_/models/backbones/dino_vit_large.py b/mono/configs/_base_/models/backbones/dino_vit_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..843178ed6e61d74070b971f01148f87fdf2a62cf
--- /dev/null
+++ b/mono/configs/_base_/models/backbones/dino_vit_large.py
@@ -0,0 +1,7 @@
+model = dict(
+ backbone=dict(
+ type='vit_large',
+ prefix='backbones.',
+ out_channels=[1024, 1024, 1024, 1024],
+ drop_path_rate = 0.0),
+ )
diff --git a/mono/configs/_base_/models/backbones/dino_vit_large_reg.py b/mono/configs/_base_/models/backbones/dino_vit_large_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..25e96747d459d42df299f8a6a1e14044a0e56164
--- /dev/null
+++ b/mono/configs/_base_/models/backbones/dino_vit_large_reg.py
@@ -0,0 +1,7 @@
+model = dict(
+ backbone=dict(
+ type='vit_large_reg',
+ prefix='backbones.',
+ out_channels=[1024, 1024, 1024, 1024],
+ drop_path_rate = 0.0),
+ )
diff --git a/mono/configs/_base_/models/backbones/dino_vit_small_reg.py b/mono/configs/_base_/models/backbones/dino_vit_small_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c8bd97dccb9cdee7517250f40e01bb3124144e6
--- /dev/null
+++ b/mono/configs/_base_/models/backbones/dino_vit_small_reg.py
@@ -0,0 +1,7 @@
+model = dict(
+ backbone=dict(
+ type='vit_small_reg',
+ prefix='backbones.',
+ out_channels=[384, 384, 384, 384],
+ drop_path_rate = 0.0),
+ )
diff --git a/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py b/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f262288c49e7ffccb6174b09b0daf80ff79dd684
--- /dev/null
+++ b/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py
@@ -0,0 +1,10 @@
+# model settings
+_base_ = ['../backbones/convnext_large.py',]
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='HourglassDecoder',
+ in_channels=[192, 384, 768, 1536],
+ decoder_channel=[128, 128, 256, 512],
+ prefix='decode_heads.'),
+)
diff --git a/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py b/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd69efefab2c03de435996c6b7b65ff941db1e5d
--- /dev/null
+++ b/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py
@@ -0,0 +1,20 @@
+# model settings
+_base_ = ['../backbones/dino_vit_large.py']
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='RAFTDepthDPT',
+ in_channels=[1024, 1024, 1024, 1024],
+ use_cls_token=True,
+ feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14]
+ decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14]
+ up_scale = 7,
+ hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
+ n_gru_layers=3,
+ n_downsample=2,
+ iters=12,
+ slow_fast_gru=True,
+ corr_radius=4,
+ corr_levels=4,
+ prefix='decode_heads.'),
+)
diff --git a/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py b/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ab6dc090e9cdb840d84fab10587becb536dbb8
--- /dev/null
+++ b/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py
@@ -0,0 +1,19 @@
+# model settings
+_base_ = ['../backbones/dino_vit_large_reg.py']
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='RAFTDepthDPT',
+ in_channels=[1024, 1024, 1024, 1024],
+ use_cls_token=True,
+ feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14]
+ decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14]
+ up_scale = 7,
+ hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
+ n_gru_layers=3,
+ n_downsample=2,
+ iters=3,
+ slow_fast_gru=True,
+ num_register_tokens=4,
+ prefix='decode_heads.'),
+)
diff --git a/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py b/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..19466c191e9f2a83903e55ca4fc0827d9a11bcb9
--- /dev/null
+++ b/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py
@@ -0,0 +1,19 @@
+# model settings
+_base_ = ['../backbones/dino_vit_small_reg.py']
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='RAFTDepthDPT',
+ in_channels=[384, 384, 384, 384],
+ use_cls_token=True,
+ feature_channels = [96, 192, 384, 768], # [2/7, 1/7, 1/14, 1/14]
+ decoder_channels = [48, 96, 192, 384, 384], # [-, 1/4, 1/7, 1/14, 1/14]
+ up_scale = 7,
+ hidden_channels=[48, 48, 48, 48], # [x_4, x_8, x_16, x_32] [1/4, 1/7, 1/14, -]
+ n_gru_layers=3,
+ n_downsample=2,
+ iters=3,
+ slow_fast_gru=True,
+ num_register_tokens=4,
+ prefix='decode_heads.'),
+)
diff --git a/mono/model/__init__.py b/mono/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1ea3d3e3b880e28ef880083b3c79e3b00cd119
--- /dev/null
+++ b/mono/model/__init__.py
@@ -0,0 +1,5 @@
+from .monodepth_model import DepthModel
+# from .__base_model__ import BaseDepthModel
+
+
+__all__ = ['DepthModel']
diff --git a/mono/model/backbones/ConvNeXt.py b/mono/model/backbones/ConvNeXt.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1c4be0e6463ae2b0dda6d20fc273a300afa5ebf
--- /dev/null
+++ b/mono/model/backbones/ConvNeXt.py
@@ -0,0 +1,271 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.registry import register_model
+
+class Block(nn.Module):
+ r""" ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ """
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
+ requires_grad=True) if layer_scale_init_value > 0 else None
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+class ConvNeXt(nn.Module):
+ r""" ConvNeXt
+ A PyTorch impl of : `A ConvNet for the 2020s` -
+ https://arxiv.org/pdf/2201.03545.pdf
+ Args:
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+ """
+ def __init__(self, in_chans=3, num_classes=1000,
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
+ layer_scale_init_value=1e-6, head_init_scale=1.,
+ **kwargs,):
+ super().__init__()
+
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
+ stem = nn.Sequential(
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+ )
+ self.downsample_layers.append(stem)
+ for i in range(3):
+ downsample_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
+ )
+ self.downsample_layers.append(downsample_layer)
+
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+ cur = 0
+ for i in range(4):
+ stage = nn.Sequential(
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ #self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
+ #self.head = nn.Linear(dims[-1], num_classes)
+
+ self.apply(self._init_weights)
+ #self.head.weight.data.mul_(head_init_scale)
+ #self.head.bias.data.mul_(head_init_scale)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ trunc_normal_(m.weight, std=.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward_features(self, x):
+ features = []
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+ features.append(x)
+        return features  # multi-scale feature maps from the four stages
+
+ def forward(self, x):
+ #x = self.forward_features(x)
+ #x = self.head(x)
+ features = self.forward_features(x)
+ return features
+
+class LayerNorm(nn.Module):
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape, )
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ elif self.data_format == "channels_first":
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+model_urls = {
+ "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+ "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+ "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+ "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+ "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
+ "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
+ "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+ "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+ "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
+}
+
+def convnext_tiny(pretrained=True,in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ #url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
+ return model
+
+def convnext_small(pretrained=True,in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ #url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d params, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys())
+ return model
+
+def convnext_base(pretrained=True, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ #url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+        print(
+            'Successfully loaded %d pretrained params; %d params were unmatched.'
+            % (len(pretrained_dict), len(unmatched_pretrained_dict)))
+        print('Unmatched pretrained params:', list(unmatched_pretrained_dict.keys()))
+ return model
+
+def convnext_large(pretrained=True, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ #url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+        print(
+            'Successfully loaded %d pretrained params; %d params were unmatched.'
+            % (len(pretrained_dict), len(unmatched_pretrained_dict)))
+        print('Unmatched pretrained params:', list(unmatched_pretrained_dict.keys()))
+ return model
+
+def convnext_xlarge(pretrained=True, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
+ if pretrained:
+ assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ #url = model_urls['convnext_xlarge_22k']
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+        print(
+            'Successfully loaded %d pretrained params; %d params were unmatched.'
+            % (len(pretrained_dict), len(unmatched_pretrained_dict)))
+        print('Unmatched pretrained params:', list(unmatched_pretrained_dict.keys()))
+ return model
+
+if __name__ == '__main__':
+ import torch
+    model = convnext_base(pretrained=False, in_22k=False).cuda()  # set pretrained=True and checkpoint='<path-to-weights>' to load local weights
+
+ rgb = torch.rand((2, 3, 256, 256)).cuda()
+ out = model(rgb)
+ print(len(out))
+ for i, ft in enumerate(out):
+ print(i, ft.shape)
diff --git a/mono/model/backbones/ViT_DINO.py b/mono/model/backbones/ViT_DINO.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a1998f0dd5024fbe69895e244fc054245a06568
--- /dev/null
+++ b/mono/model/backbones/ViT_DINO.py
@@ -0,0 +1,1504 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+logger = logging.getLogger("dinov2")
+
+class ConvBlock(nn.Module):
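+    """Pre-activation residual block: (BN -> ReLU -> 3x3 conv) applied twice, added back to the input."""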
+ def __init__(self, channels):
+ super(ConvBlock, self).__init__()
+
+ self.act = nn.ReLU(inplace=True)
+ self.conv1 = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ )
+ self.norm1 = nn.BatchNorm2d(channels)
+ self.conv2 = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ )
+ self.norm2 = nn.BatchNorm2d(channels)
+
+ def forward(self, x):
+
+ out = self.norm1(x)
+ out = self.act(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = self.act(out)
+ out = self.conv2(out)
+ return x + out
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
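+    """Per-sample stochastic depth: zero out whole samples in a residual branch and rescale survivors by 1/keep_prob."""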
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+class LayerScale(nn.Module):
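+    """Learnable per-channel scale (gamma) applied to the output of a residual branch."""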
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class SwiGLUFFN(nn.Module):
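+    """SwiGLU feed-forward: project to 2x hidden width, split, gate one half with SiLU, then project back down."""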
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+ from xformers.components.attention import ScaledDotProduct
+ from xformers.components import MultiHeadDispatch
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ window_size: int = 0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ #if not self.training:
+ #
+ # self.attn = ScaledDotProduct()
+ #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ if attn_bias is not None:
+ attn = attn + attn_bias[:, :, :N]
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
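+    """Attention that uses xFormers memory_efficient_attention when available and falls back to the vanilla implementation above otherwise."""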
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ #if True:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x, attn_bias)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+ if attn_bias is not None:
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N])
+ else:
+ x = memory_efficient_attention(q, k, v)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values = None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ def attn_residual_func(x: Tensor, attn_bias) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ attn_bias=attn_bias
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, attn_bias))
+            x = x + self.drop_path2(ffn_residual_func(x))
+ else:
+ x = x + attn_residual_func(x, attn_bias)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0, attn_bias=None
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset, attn_bias)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> List[Tensor]:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list, attn_bias=None):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list, attn_bias)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
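+    """Runs its blocks in sequence; leading nn.Identity placeholders keep global block indices stable when the block list is chunked for FSDP wrapping."""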
+ def forward(self, x, others=None):
+ for b in self:
+            if others is None:
+ x = b(x)
+ else:
+ x = b(x, others)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ #init_values=None, # for layerscale: None or 0 => no layerscale
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=NestedTensorBlock,
+ ffn_layer="mlp",
+ block_chunks=1,
+ window_size=37,
+ **kwargs
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.window_size = window_size
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + 0.1, h0 + 0.1
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+ mode="bicubic",
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ B, C, H, W = x.size()
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
+ if pad_h + pad_w > 0:
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ features = []
+ for blk in self.blocks:
+ x = blk(x)
+ # for idx in range(len(self.blocks[0])):
+ # x = self.blocks[0][idx](x)
+ # if (idx + 1) % (len(self.blocks[0]) // 4) == 0:
+ # features.append(x)
+
+ #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ x_norm = self.norm(x)
+ # return {
+ # "x_norm_clstoken": x_norm[:, 0],
+ # "x_norm_patchtokens": x_norm[:, 1:],
+ # "x_prenorm": x,
+ # "masks": masks,
+ # }
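+        # Reuse the final normalized tokens for all four feature levels expected downstream (the per-stage taps above are left commented out).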
+ features = []
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ return ret
+ # if is_training:
+ # return ret
+ # else:
+ # return self.head(ret["x_norm_clstoken"])
+
+
+class PosConv(nn.Module):
+ # PEG from https://arxiv.org/abs/2102.10882
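+    # A depthwise (groups=embed_dim) 37x37 convolution produces position-dependent features that are added back to the token map when stride == 1.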
+ def __init__(self, in_chans, embed_dim=768, stride=1):
+ super(PosConv, self).__init__()
+ self.proj = nn.Sequential(
+ nn.Conv2d(in_chans, embed_dim, 37, stride, 18, bias=True, groups=embed_dim),
+ )
+ self.stride = stride
+
+ def forward(self, x, size):
+ B, N, C = x.shape
+ cnn_feat_token = x.transpose(1, 2).view(B, C, *size)
+ x = self.proj(cnn_feat_token)
+ if self.stride == 1:
+ x += cnn_feat_token
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+ #def no_weight_decay(self):
+ #return ['proj.%d.weight' % i for i in range(4)]
+
+class DinoWindowVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ #init_values=None, # for layerscale: None or 0 => no layerscale
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=NestedTensorBlock,
+ ffn_layer="mlp",
+ block_chunks=1,
+ window_size=7,
+ **kwargs
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ #self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ #self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+
+ self.pos_conv = PosConv(self.embed_dim, self.embed_dim)
+
+ self.window_size = window_size
+ #self.conv_block = nn.ModuleList([ConvBlock(embed_dim) for i in range(4)])
+ #self.conv_block = nn.ModuleList([nn.Identity() for i in range(4)])
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.nh = -1
+ self.nw = -1
+ try:
+ H = cfg.data_basic['crop_size'][0]
+ W = cfg.data_basic['crop_size'][1]
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
+ self.nh = (H + pad_h) // self.patch_size
+ self.nw = (W + pad_w) // self.patch_size
+ self.prepare_attn_bias((self.nh, self.nw))
+        except Exception:
+            # cfg / crop_size is not available at construction time; the attention bias is built lazily in forward_features instead.
+            pass
+ self.init_weights()
+
+ self.total_step = 10000 # For PE -> GPE transfer
+ self.start_step = 2000
+ self.current_step = 20000
+
+ def init_weights(self):
+ #trunc_normal_(self.pos_embed, std=0.02)
+ #nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+ for i in range(4):
+ try:
+ nn.init.constant_(self.conv_block[i].conv2.weight, 0.0)
+            except Exception:
+                # conv_block is disabled above, so this initialization is skipped.
+                pass
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ #npatch = x.shape[1] - 1
+ #N = self.pos_embed.shape[1] - 1
+ npatch = x.shape[1]
+ N = self.pos_embed.shape[1]
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ #class_pos_embed = pos_embed[:, 0]
+ #patch_pos_embed = pos_embed[:, 1:]
+ patch_pos_embed = pos_embed
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + 0.1, h0 + 0.1
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+ mode="bicubic",
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return patch_pos_embed.to(previous_dtype)
+ #return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def window_partition(self, x: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ """
+ Partition into non-overlapping windows with padding if needed.
+ Args:
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
+
+ Returns:
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
+ """
+        if not conv_feature:
+ B, N, C = x.shape
+ H, W = hw[0], hw[1]
+
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)
+ else:
+ B, C, H, W = x.shape
+
+ x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
+
+ windows = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, window_size * window_size, C)
+
+ #y = torch.cat((x_cls, windows), dim=1)
+ return windows #, (Hp, Wp)
+
+
+ def window_unpartition(self,
+ windows: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False
+ ) -> torch.Tensor:
+ """
+ Window unpartition into original sequences and removing padding.
+ Args:
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
+
+ Returns:
+ x: unpartitioned sequences with [B, H, W, C].
+ """
+ H, W = hw
+
+ B = windows.shape[0] // (H * W // window_size // window_size)
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+
+        if not conv_feature:
+            x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H * W, -1)
+ else:
+ C = windows.shape[-1]
+ x = x.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, C, H, W)
+
+ # if Hp > H or Wp > W:
+ # x = x[:, :H, :W, :].contiguous()
+ return x
+
+ def prepare_tokens_with_masks(self, x, masks=None, step=-1):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ #x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ if step == -1:
+ step = self.current_step
+ else:
+ self.current_step = step
+
+ if step < self.start_step:
+ coef = 0.0
+ elif step < self.total_step:
+ coef = (step - self.start_step) / (self.total_step - self.start_step)
+ else:
+ coef = 1.0
+
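+        # Blend linearly from the interpolated absolute positional embedding (coef = 0) to the conv positional encoding (coef = 1) over training.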
+ x = x + (1 - coef) * self.interpolate_pos_encoding(x, w, h) + coef * self.pos_conv(x, (self.nh, self.nw))
+
+ return x
+
+ def prepare_attn_bias(self, shape):
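+        """Build an additive bias for windowed attention: 0 inside the local window, -inf outside, padded along the key dimension to a multiple of 8 for xFormers, and cached as a buffer."""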
+ window_size = self.window_size
+ if window_size <= 0:
+ return
+
+ import xformers.components.attention.attention_patterns as AP
+
+ nh, nw = shape
+ radius = (window_size-1)//2
+ mask_ori = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
+
+ pad = (8 - (nh * nw) % 8)
+ if pad == 8:
+ pad = 0
+ mask_pad = nn.functional.pad(mask_ori, (0, pad)).contiguous()
+ if pad > 0:
+ mask = mask_pad[:, :-pad].view(nh, nw, nh, nw)
+ else:
+ mask = mask_pad[:, :].view(nh, nw, nh, nw)
+
+ # angle
+ mask[:radius+1, :radius+1, :window_size, :window_size] = True
+ mask[:radius+1, -radius-1:, :window_size, -window_size:] = True
+ mask[-radius-1:, :radius+1, -window_size:, :window_size] = True
+ mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True
+
+ # edge
+ mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :]
+ mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :]
+ mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :]
+ mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :]
+
+ mask = mask.view(nh*nw, nh*nw)
+ bias_pad = torch.log(mask_pad)
+ #bias = bias_pad[:, :-pad]
+ self.register_buffer('attn_bias', bias_pad)
+
+ return bias_pad
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None, **kwargs):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ B, C, H, W = x.size()
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
+ if pad_h + pad_w > 0:
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
+
+ nh = (H+pad_h)//self.patch_size
+ nw = (W+pad_w)//self.patch_size
+
+ if self.window_size > 0:
+ if nh == self.nh and nw == self.nw:
+ attn_bias = self.attn_bias
+ else:
+ attn_bias = self.prepare_attn_bias(((H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size))
+ self.nh = nh
+ self.nw = nw
+ attn_bias = attn_bias.unsqueeze(0).repeat(B * self.num_heads, 1, 1)
+ else:
+ attn_bias = None
+
+ x = self.prepare_tokens_with_masks(x, masks)
+ #x = self.patch_embed(x)
+
+ features = []
+ #x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size))
+ for blk in self.blocks:
+ x = blk(x, attn_bias)
+ #x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size))
+
+ # for idx in range(len(self.blocks[0])):
+ # x = self.blocks[0][idx](x, attn_bias)
+
+ # if (idx + 1) % (len(self.blocks[0]) // 4) == 0:
+ # x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True)
+ # x = self.conv_block[idx // (len(self.blocks[0]) // 4)](x)
+ # if idx + 1 != len(self.blocks[0]):
+ # x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True)
+ # else:
+ # b, c, h, w = x.size()
+ # x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, c)
+ #features.append(x)
+
+ #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ x_norm = self.norm(x)
+ # return {
+ # "x_norm_clstoken": x_norm[:, 0],
+ # "x_norm_patchtokens": x_norm[:, 1:],
+ # "x_prenorm": x,
+ # "masks": masks,
+ # }
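+        # Reuse the final normalized tokens for all four feature levels expected downstream (the per-stage taps above are left commented out).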
+ features = []
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ return ret
+ # if is_training:
+ # return ret
+ # else:
+ # return self.head(ret["x_norm_clstoken"])
+
+
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=14, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=14, **kwargs):
+ model = DinoWindowVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=14, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ img_size = 518,
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ **kwargs,
+ )
+
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ try:
+ model.load_state_dict(state_dict, strict=True)
+        except Exception:
+            # Key names differ when the blocks are chunked; remap 'blocks.*' -> 'blocks.0.*' and retry.
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if 'blocks' in key:
+ key_new = 'blocks.0' + key[len('blocks'):]
+ else:
+ key_new = key
+ new_state_dict[key_new] = value
+
+ model.load_state_dict(new_state_dict, strict=True)
+ #del model.norm
+ del model.mask_token
+ return model
+
+ # model = DinoWindowVisionTransformer(
+ # img_size = 518,
+ # patch_size=patch_size,
+ # embed_dim=1024,
+ # depth=24,
+ # num_heads=16,
+ # mlp_ratio=4,
+ # block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ # window_size=37,
+ # **kwargs,
+ # )
+
+ # if checkpoint is not None:
+ # with open(checkpoint, "rb") as f:
+ # state_dict = torch.load(f)
+ # try:
+ # model.load_state_dict(state_dict, strict=True)
+ # except:
+ # new_state_dict = {}
+ # for key, value in state_dict.items():
+ # if 'blocks' in key:
+ # key_new = 'blocks.0' + key[len('blocks'):]
+ # else:
+ # key_new = key
+ # if 'pos_embed' in key:
+ # value = value[:, 1:, :]
+ # new_state_dict[key_new] = value
+
+ # model.load_state_dict(new_state_dict, strict=False)
+ # #del model.norm
+ # del model.mask_token
+ return model
+
+
+def vit_giant2(patch_size=16, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+if __name__ == '__main__':
+ try:
+ from mmcv.utils import Config
+    except ImportError:
+ from mmengine import Config
+
+ #rgb = torch.rand((2, 3, 518, 518)).cuda()
+
+ #cfg.data_basic['crop_size']['0']
+ #cfg.data_basic['crop_size']['1']
+ cfg = Config.fromfile('/cpfs01/user/mu.hu/monodepth/mono/configs/HourglassDecoder/pub12.convlarge.0.3_150.py')
+
+ #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036)
+ rgb = torch.zeros(1, 3, 1400, 1680).cuda()
+ model = vit_large(checkpoint="/cpfs02/shared/public/custom/group_local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth", kwarg=cfg).cuda()
+
+ #import timm
+ #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda()
+ #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn)
+
+ out1 = model(rgb)
+ #out2 = model2(rgb)
+ temp = 0
+
+
+
+# import time
+# window_size = 37
+# def prepare_window_masks(shape):
+# if window_size <= 0:
+# return None
+# import xformers.components.attention.attention_patterns as AP
+
+# B, nh, nw, _, _ = shape
+# radius = (window_size-1)//2
+# #time0 = time.time()
+# d = AP.local_nd_distance(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
+# #mask = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
+# # mask = mask.view(nh, nw, nh, nw)
+# # #time1 = time.time() - time0
+
+# # # angle
+# # mask[:radius+1, :radius+1, :window_size, :window_size] = True
+# # mask[:radius+1, -radius-1:, :window_size, -window_size:] = True
+# # mask[-radius-1:, :radius+1, -window_size:, :window_size] = True
+# # mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True
+# # time2 = time.time() - time0 - time1
+
+# # # edge
+# # mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :]
+# # mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :]
+# # mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :]
+# # mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :]
+# # time3 = time.time() - time0 - time2
+# # print(time1, time2, time3)
+
+# # return mask.view(nw*nw, nh*nw).unsqueeze(0).repeat(B, 1)
+
+# shape = (1, 55, 55, None, None)
+# mask = prepare_window_masks(shape)
+# # temp = 1
\ No newline at end of file
diff --git a/mono/model/backbones/ViT_DINO_reg.py b/mono/model/backbones/ViT_DINO_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..854f96320ea93752e023c8cd845bf38353dfab17
--- /dev/null
+++ b/mono/model/backbones/ViT_DINO_reg.py
@@ -0,0 +1,1293 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+import torch.nn.init
+import torch.nn.functional as F
+
+#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+logger = logging.getLogger("dinov2")
+
+# SSF finetuning originally by dongzelian
+def init_ssf_scale_shift(dim):
+ scale = nn.Parameter(torch.ones(dim))
+ shift = nn.Parameter(torch.zeros(dim))
+
+ nn.init.normal_(scale, mean=1, std=.02)
+ nn.init.normal_(shift, std=.02)
+
+ return scale, shift
+
+def ssf_ada(x, scale, shift):
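+    """Apply the SSF scale and shift along the channel dimension (last dim for token tensors, dim 1 for NCHW feature maps)."""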
+ assert scale.shape == shift.shape
+ if x.shape[-1] == scale.shape[0]:
+ return x * scale + shift
+ elif x.shape[1] == scale.shape[0]:
+ return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
+ else:
+ raise ValueError('the input tensor shape does not match the shape of the scale factor.')
+
+# LoRA finetuning originally by edwardjhu
+class LoRALayer:
+ def __init__(
+ self,
+ r: int,
+ lora_alpha: int,
+ lora_dropout: float,
+ merge_weights: bool,
+ ):
+ self.r = r
+ self.lora_alpha = lora_alpha
+ # Optional dropout
+ if lora_dropout > 0.:
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
+ else:
+ self.lora_dropout = lambda x: x
+ # Mark the weight as unmerged
+ self.merged = False
+ self.merge_weights = merge_weights
+
+class LoRALinear(nn.Linear, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+
+ self.fan_in_fan_out = fan_in_fan_out
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+ self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+ if fan_in_fan_out:
+ self.weight.data = self.weight.data.transpose(0, 1)
+
+ def reset_parameters(self):
+ #nn.Linear.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+            # initialize A the same way as the default for nn.Linear and B to zero
+ # this is different than what is described in the paper but should not affect performance
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ # def train(self, mode: bool = True):
+ # def T(w):
+ # return w.transpose(0, 1) if self.fan_in_fan_out else w
+ # nn.Linear.train(self, mode)
+ # if mode:
+ # if self.merge_weights and self.merged:
+ # # Make sure that the weights are not merged
+ # if self.r > 0:
+ # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ # self.merged = False
+ # else:
+ # if self.merge_weights and not self.merged:
+ # # Merge the weights and mark it
+ # if self.r > 0:
+ # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ # self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+ if self.r > 0 and not self.merged:
+ result = F.linear(x, T(self.weight), bias=self.bias)
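+            # Add the low-rank update: dropout(x) @ A^T @ B^T, scaled by lora_alpha / r.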
+ result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+ return result
+ else:
+ return F.linear(x, T(self.weight), bias=self.bias)
+
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ tuning_mode: Optional[str] = None
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ if tuning_mode != None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if self.tuning_mode == 'ssf':
+ x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ tuning_mode: Optional[int] = None
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+        if tuning_mode is not None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features)
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ if self.tuning_mode == 'ssf':
+ x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1)
+
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ if self.tuning_mode == 'ssf':
+ x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
+
+ x = self.drop(x)
+ return x
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ tuning_mode: Optional[int] = None
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+        if tuning_mode is not None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(2 * hidden_features)
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ if self.tuning_mode == 'ssf':
+ x12 = ssf_ada(x12, self.ssf_scale_1, self.ssf_shift_1)
+
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ out = self.w3(hidden)
+
+ if self.tuning_mode == 'ssf':
+            out = ssf_ada(out, self.ssf_scale_2, self.ssf_shift_2)
+
+ return out
+
+
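+# Prefer the fused xFormers SwiGLU when it is installed; otherwise fall back to the plain
+# SwiGLUFFN implementation above.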
+try:
+ from xformers.ops import SwiGLU
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+ from xformers.components.attention import ScaledDotProduct
+ from xformers.components import MultiHeadDispatch
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ window_size: int = 0,
+ tuning_mode: Optional[int] = None
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ if tuning_mode == 'lora':
+ self.tuning_mode = tuning_mode
+ self.qkv = LoRALinear(dim, dim * 3, bias=qkv_bias, r=8)
+ else:
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+
+ if tuning_mode == 'lora':
+ self.tuning_mode = tuning_mode
+ self.proj = LoRALinear(dim, dim, bias=proj_bias, r=8)
+ else:
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+        if tuning_mode is not None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3)
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+
+ #if not self.training:
+ #
+ # self.attn = ScaledDotProduct()
+ #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ if self.tuning_mode == 'ssf':
+ qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ else:
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ if attn_bias is not None:
+ attn = attn + attn_bias[:, :, :N]
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+
+ if self.tuning_mode == 'ssf':
+ x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
+
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ #if True:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x, attn_bias)
+
+ B, N, C = x.shape
+ if self.tuning_mode == 'ssf':
+ qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ else:
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+ if attn_bias is not None:
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N])
+ else:
+ x = memory_efficient_attention(q, k, v)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ if self.tuning_mode == 'ssf':
+ x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2)
+
+ x = self.proj_drop(x)
+ return x
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values = None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ tuning_mode: Optional[int] = None
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ tuning_mode=tuning_mode
+ )
+
+        if tuning_mode is not None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim)
+ self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ def attn_residual_func(x: Tensor, attn_bias) -> Tensor:
+ if self.tuning_mode == 'ssf':
+ return self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1), attn_bias))
+ else:
+ return self.ls1(self.attn(self.norm1(x), attn_bias))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ if self.tuning_mode == 'ssf':
+ return self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2)))
+ else:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ attn_bias=attn_bias
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, attn_bias))
+            x = x + self.drop_path2(ffn_residual_func(x))
+ else:
+ x = x + attn_residual_func(x, attn_bias)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0, attn_bias=None
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset, attn_bias)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
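+# Pick a random subset of the batch to run through the residual branch and return the chosen
+# indices together with the factor (b / subset_size) used to rescale the residual afterwards.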
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
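+# Cache of BlockDiagonalMask attention biases, keyed by the (batch_size, seq_len) shapes of the
+# tensors in x_list, so each mask is built only once per shape combination.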
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+    if all_shapes not in attn_bias_cache:
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list, attn_bias=None):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list, attn_bias)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
+
+
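+# Recursively apply `fn` to a module and all of its children (depth-first by default),
+# in the style of timm's named_apply helper.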
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x, others=None):
+ for b in self:
+            if others is None:
+ x = b(x)
+ else:
+ x = b(x, others)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ tuning_mode=None,
+ **kwargs
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+        if tuning_mode is not None:
+ self.tuning_mode = tuning_mode
+ if tuning_mode == 'ssf':
+ self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim)
+ else:
+ pass
+ #raise NotImplementedError()
+ else:
+ self.tuning_mode = None
+ tuning_mode_list = [tuning_mode] * depth
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ tuning_mode=tuning_mode_list[i]
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+
+ sqrt_N = math.sqrt(N)
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+ scale_factor=(sx, sy),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2]
+ assert int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ B, C, H, W = x.size()
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
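+        # Resize the input (rather than zero-padding it) so H and W become exact multiples of the patch size.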
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
+ if pad_h + pad_w > 0:
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x_norm = self.norm(x)
+ if self.tuning_mode == 'ssf':
+ x_norm = ssf_ada(x_norm, self.ssf_scale_1, self.ssf_shift_1)
+
+ # return {
+ # "x_norm_clstoken": x_norm[:, 0],
+ # "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ # "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ # "x_prenorm": x,
+ # "masks": masks,
+ # }
+        # Four feature levels are expected downstream; reuse the same normalized tokens for all of them.
+        features = [x_norm] * 4
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)]
+
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ return ret
+ # if is_training:
+ # return ret
+ # else:
+ # return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def load_ckpt_dino(checkpoint, model):
+ if checkpoint is not None:
+ try:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+        except Exception:
+            print('No pretrained ImageNet checkpoint available! Check your path!')
+ del model.mask_token
+ return
+
+ try:
+ model.load_state_dict(state_dict, strict=True)
+        except Exception:
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if 'blocks' in key:
+ key_new = 'blocks.0' + key[len('blocks'):]
+ else:
+ key_new = key
+ new_state_dict[key_new] = value
+
+ model.load_state_dict(new_state_dict, strict=True)
+ del model.mask_token
+ return
+ else:
+ return
+
+
+def vit_small(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_base(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ try:
+ model.load_state_dict(state_dict, strict=True)
+        except Exception:
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if 'blocks' in key:
+ key_new = 'blocks.0' + key[len('blocks'):]
+ else:
+ key_new = key
+ new_state_dict[key_new] = value
+
+ model.load_state_dict(new_state_dict, strict=True)
+ del model.mask_token
+ return model
+
+
+def vit_giant2(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ ffn_layer='swiglu',
+ **kwargs,
+ )
+ return model
+
+
+
+def vit_small_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ tuning_mode=tuning_mode,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_base_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_large_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
+ model = DinoVisionTransformer(
+ img_size = 518,
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ tuning_mode=tuning_mode,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_giant2_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ ffn_layer='swiglu',
+ tuning_mode=tuning_mode,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+if __name__ == '__main__':
+ try:
+ from mmcv.utils import Config
+    except ImportError:
+ from mmengine import Config
+
+ #rgb = torch.rand((2, 3, 518, 518)).cuda()
+
+ #cfg.data_basic['crop_size']['0']
+ #cfg.data_basic['crop_size']['1']
+ cfg = Config.fromfile('/opt/ml/project/mu.hu/projects/monodepth_vit/mono/configs/RAFTDecoder/vit.raft5.large.kitti.py')
+
+ #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036)
+ rgb = torch.zeros(1, 3, 616, 1064).cuda()
+ cfg['tuning_mode'] = 'ssf'
+ #model = vit_large_reg(checkpoint="/cpfs02/shared/public/groups/local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_reg4_pretrain.pth", kwarg=cfg).cuda()
+ model = vit_large_reg(tuning_mode='ssf').cuda()
+
+ #import timm
+ #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda()
+ #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn)
+
+ out1 = model(rgb)
+ #out2 = model2(rgb)
+ temp = 0
+
+
diff --git a/mono/model/backbones/__init__.py b/mono/model/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cc3ba70ef5ef867f0518d73a189e7531466cbab
--- /dev/null
+++ b/mono/model/backbones/__init__.py
@@ -0,0 +1,11 @@
+from .ConvNeXt import convnext_xlarge
+from .ConvNeXt import convnext_small
+from .ConvNeXt import convnext_base
+from .ConvNeXt import convnext_large
+from .ConvNeXt import convnext_tiny
+from .ViT_DINO import vit_large
+from .ViT_DINO_reg import vit_small_reg, vit_large_reg
+
+__all__ = [
+ 'convnext_xlarge', 'convnext_small', 'convnext_base', 'convnext_large', 'convnext_tiny', 'vit_small_reg', 'vit_large_reg'
+]
diff --git a/mono/model/decode_heads/HourGlassDecoder.py b/mono/model/decode_heads/HourGlassDecoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e084382601e21e6ce5144abbd6a65f563905b659
--- /dev/null
+++ b/mono/model/decode_heads/HourGlassDecoder.py
@@ -0,0 +1,274 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+import torch.nn.functional as F
+
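+# Expected depth per pixel: probability-weighted sum of the candidate depth values along dim 1.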
+def compute_depth_expectation(prob, depth_values):
+ depth_values = depth_values.view(*depth_values.shape, 1, 1)
+ depth = torch.sum(prob * depth_values, 1)
+ return depth
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=3):
+ super(ConvBlock, self).__init__()
+
+ if kernel_size == 3:
+ self.conv = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
+ )
+ elif kernel_size == 1:
+ self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)
+
+ self.nonlin = nn.ELU(inplace=True)
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.nonlin(out)
+ return out
+
+
+class ConvBlock_double(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=3):
+ super(ConvBlock_double, self).__init__()
+
+ if kernel_size == 3:
+ self.conv = nn.Sequential(
+ nn.ReflectionPad2d(1),
+ nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1),
+ )
+ elif kernel_size == 1:
+ self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1)
+
+ self.nonlin = nn.ELU(inplace=True)
+ self.conv_2 = nn.Conv2d(out_channels, out_channels, 1, padding=0, stride=1)
+        self.nonlin_2 = nn.ELU(inplace=True)
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.nonlin(out)
+ out = self.conv_2(out)
+ out = self.nonlin_2(out)
+ return out
+
+class DecoderFeature(nn.Module):
+ def __init__(self, feat_channels, num_ch_dec=[64, 64, 128, 256]):
+ super(DecoderFeature, self).__init__()
+ self.num_ch_dec = num_ch_dec
+ self.feat_channels = feat_channels
+
+ self.upconv_3_0 = ConvBlock(self.feat_channels[3], self.num_ch_dec[3], kernel_size=1)
+ self.upconv_3_1 = ConvBlock_double(
+ self.feat_channels[2] + self.num_ch_dec[3],
+ self.num_ch_dec[3],
+ kernel_size=1)
+
+ self.upconv_2_0 = ConvBlock(self.num_ch_dec[3], self.num_ch_dec[2], kernel_size=3)
+ self.upconv_2_1 = ConvBlock_double(
+ self.feat_channels[1] + self.num_ch_dec[2],
+ self.num_ch_dec[2],
+ kernel_size=3)
+
+ self.upconv_1_0 = ConvBlock(self.num_ch_dec[2], self.num_ch_dec[1], kernel_size=3)
+ self.upconv_1_1 = ConvBlock_double(
+ self.feat_channels[0] + self.num_ch_dec[1],
+ self.num_ch_dec[1],
+ kernel_size=3)
+ self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
+
+ def forward(self, ref_feature):
+ x = ref_feature[3]
+
+ x = self.upconv_3_0(x)
+ x = torch.cat((self.upsample(x), ref_feature[2]), 1)
+ x = self.upconv_3_1(x)
+
+ x = self.upconv_2_0(x)
+ x = torch.cat((self.upsample(x), ref_feature[1]), 1)
+ x = self.upconv_2_1(x)
+
+ x = self.upconv_1_0(x)
+ x = torch.cat((self.upsample(x), ref_feature[0]), 1)
+ x = self.upconv_1_1(x)
+ return x
+
+
+class UNet(nn.Module):
+ def __init__(self, inp_ch=32, output_chal=1, down_sample_times=3, channel_mode='v0'):
+ super(UNet, self).__init__()
+ basic_block = ConvBnReLU
+ num_depth = 128
+
+ self.conv0 = basic_block(inp_ch, num_depth)
+ if channel_mode == 'v0':
+ channels = [num_depth, num_depth//2, num_depth//4, num_depth//8, num_depth // 8]
+ elif channel_mode == 'v1':
+ channels = [num_depth, num_depth, num_depth, num_depth, num_depth, num_depth]
+ self.down_sample_times = down_sample_times
+ for i in range(down_sample_times):
+ setattr(
+ self, 'conv_%d' % i,
+ nn.Sequential(
+ basic_block(channels[i], channels[i+1], stride=2),
+ basic_block(channels[i+1], channels[i+1])
+ )
+ )
+ for i in range(down_sample_times-1,-1,-1):
+ setattr(self, 'deconv_%d' % i,
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ channels[i+1],
+ channels[i],
+ kernel_size=3,
+ padding=1,
+ output_padding=1,
+ stride=2,
+ bias=False),
+ nn.BatchNorm2d(channels[i]),
+ nn.ReLU(inplace=True)
+ )
+ )
+ self.prob = nn.Conv2d(num_depth, output_chal, 1, stride=1, padding=0)
+
+ def forward(self, x):
+ features = {}
+ conv0 = self.conv0(x)
+ x = conv0
+ features[0] = conv0
+ for i in range(self.down_sample_times):
+ x = getattr(self, 'conv_%d' % i)(x)
+ features[i+1] = x
+ for i in range(self.down_sample_times-1,-1,-1):
+ x = features[i] + getattr(self, 'deconv_%d' % i)(x)
+ x = self.prob(x)
+ return x
+
+class ConvBnReLU(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
+ super(ConvBnReLU, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=pad,
+ bias=False
+ )
+ self.bn = nn.BatchNorm2d(out_channels)
+
+ def forward(self, x):
+ return F.relu(self.bn(self.conv(x)), inplace=True)
+
+
+class HourglassDecoder(nn.Module):
+ def __init__(self, cfg):
+ super(HourglassDecoder, self).__init__()
+ self.inchannels = cfg.model.decode_head.in_channels # [256, 512, 1024, 2048]
+ self.decoder_channels = cfg.model.decode_head.decoder_channel # [64, 64, 128, 256]
+ self.min_val = cfg.data_basic.depth_normalize[0]
+ self.max_val = cfg.data_basic.depth_normalize[1]
+
+ self.num_ch_dec = self.decoder_channels # [64, 64, 128, 256]
+ self.num_depth_regressor_anchor = 512
+ self.feat_channels = self.inchannels
+ unet_in_channel = self.num_ch_dec[1]
+ unet_out_channel = 256
+
+ self.decoder_mono = DecoderFeature(self.feat_channels, self.num_ch_dec)
+ self.conv_out_2 = UNet(inp_ch=unet_in_channel,
+ output_chal=unet_out_channel + 1,
+ down_sample_times=3,
+ channel_mode='v0',
+ )
+
+ self.depth_regressor_2 = nn.Sequential(
+ nn.Conv2d(unet_out_channel,
+ self.num_depth_regressor_anchor,
+ kernel_size=3,
+ padding=1,
+ ),
+ nn.BatchNorm2d(self.num_depth_regressor_anchor),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(
+ self.num_depth_regressor_anchor,
+ self.num_depth_regressor_anchor,
+ kernel_size=1,
+ )
+ )
+ self.residual_channel = 16
+ self.conv_up_2 = nn.Sequential(
+ nn.Conv2d(1 + 2 + unet_out_channel, self.residual_channel, 3, padding=1),
+ nn.BatchNorm2d(self.residual_channel),
+ nn.ReLU(),
+ nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
+ nn.Upsample(scale_factor=4),
+ nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1),
+ nn.ReLU(),
+ nn.Conv2d(self.residual_channel, 1, 1, padding=0),
+ )
+
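+    # Depth bin centers spaced uniformly in log space between min_val and max_val.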
+ def get_bins(self, bins_num):
+ depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device='cuda')
+ depth_bins_vec = torch.exp(depth_bins_vec)
+ return depth_bins_vec
+
+ def register_depth_expectation_anchor(self, bins_num, B):
+ depth_bins_vec = self.get_bins(bins_num)
+ depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
+ self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)
+
+ def upsample(self, x, scale_factor=2):
+ return F.interpolate(x, scale_factor=scale_factor, mode='nearest')
+
+ def regress_depth_2(self, feature_map_d):
+ prob = self.depth_regressor_2(feature_map_d).softmax(dim=1)
+ B = prob.shape[0]
+ if "depth_expectation_anchor" not in self._buffers:
+ self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B)
+ d = compute_depth_expectation(
+ prob,
+ self.depth_expectation_anchor[:B, ...]
+ ).unsqueeze(1)
+ return d
+
+ def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
+ y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
+ torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
+ meshgrid = torch.stack((x, y))
+ meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
+ return meshgrid
+
+ def forward(self, features_mono, **kwargs):
+ '''
+ trans_ref2src: list of transformation matrix from the reference view to source view. [B, 4, 4]
+ inv_intrinsic_pool: list of inverse intrinsic matrix.
+ features_mono: features of reference and source views. [[ref_f1, ref_f2, ref_f3, ref_f4],[src1_f1, src1_f2, src1_f3, src1_f4], ...].
+ '''
+ outputs = {}
+ # get encoder feature of the reference view
+ ref_feat = features_mono
+
+ feature_map_mono = self.decoder_mono(ref_feat)
+ feature_map_mono_pred = self.conv_out_2(feature_map_mono)
+ confidence_map_2 = feature_map_mono_pred[:, -1:, :, :]
+ feature_map_d_2 = feature_map_mono_pred[:, :-1, :, :]
+
+ depth_pred_2 = self.regress_depth_2(feature_map_d_2)
+
+ B, _, H, W = depth_pred_2.shape
+
+ meshgrid = self.create_mesh_grid(H, W, B)
+
+ depth_pred_mono = self.upsample(depth_pred_2, scale_factor=4) + 1e-1 * \
+ self.conv_up_2(
+ torch.cat((depth_pred_2, meshgrid[:B, ...], feature_map_d_2), 1)
+ )
+ confidence_map_mono = self.upsample(confidence_map_2, scale_factor=4)
+
+ outputs=dict(
+ prediction=depth_pred_mono,
+ confidence=confidence_map_mono,
+ pred_logit=None,
+ )
+ return outputs
\ No newline at end of file
diff --git a/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py b/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py
new file mode 100644
index 0000000000000000000000000000000000000000..9af89f9b4b1878a2e4bcfcd489075c2e97cd8d3d
--- /dev/null
+++ b/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py
@@ -0,0 +1,1033 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+import torch.nn.functional as F
+
+# LORA finetuning originally by edwardjhu
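+# LoRALayer holds the shared bookkeeping (rank r, scaling lora_alpha / r, optional dropout and
+# merge flags); the Linear / Conv2d / ConvTranspose2d variants below add the trainable low-rank
+# A / B parameters on top of a frozen base weight.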
+class LoRALayer():
+ def __init__(
+ self,
+ r: int,
+ lora_alpha: int,
+ lora_dropout: float,
+ merge_weights: bool,
+ ):
+ self.r = r
+ self.lora_alpha = lora_alpha
+ # Optional dropout
+ if lora_dropout > 0.:
+ self.lora_dropout = nn.Dropout(p=lora_dropout)
+ else:
+ self.lora_dropout = lambda x: x
+ # Mark the weight as unmerged
+ self.merged = False
+ self.merge_weights = merge_weights
+
+class LoRALinear(nn.Linear, LoRALayer):
+ # LoRA implemented in a dense layer
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.,
+ fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+ merge_weights: bool = True,
+ **kwargs
+ ):
+ nn.Linear.__init__(self, in_features, out_features, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
+ merge_weights=merge_weights)
+
+ self.fan_in_fan_out = fan_in_fan_out
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+ self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+ if fan_in_fan_out:
+ self.weight.data = self.weight.data.transpose(0, 1)
+
+ def reset_parameters(self):
+ #nn.Linear.reset_parameters(self)
+ if hasattr(self, 'lora_A'):
+ # initialize B the same way as the default for nn.Linear and A to zero
+ # this is different than what is described in the paper but should not affect performance
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ # def train(self, mode: bool = True):
+ # def T(w):
+ # return w.transpose(0, 1) if self.fan_in_fan_out else w
+ # nn.Linear.train(self, mode)
+ # if mode:
+ # if self.merge_weights and self.merged:
+ # # Make sure that the weights are not merged
+ # if self.r > 0:
+ # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
+ # self.merged = False
+ # else:
+ # if self.merge_weights and not self.merged:
+ # # Merge the weights and mark it
+ # if self.r > 0:
+ # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
+ # self.merged = True
+
+ def forward(self, x: torch.Tensor):
+ def T(w):
+ return w.transpose(0, 1) if self.fan_in_fan_out else w
+ if self.r > 0 and not self.merged:
+ result = F.linear(x, T(self.weight), bias=self.bias)
+ result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+ return result
+ else:
+ return F.linear(x, T(self.weight), bias=self.bias)
+
+class ConvLoRA(nn.Conv2d, LoRALayer):
+ def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
+ #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
+ nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
+ assert isinstance(kernel_size, int)
+
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(
+ self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
+ )
+ self.lora_B = nn.Parameter(
+ self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
+ )
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+ self.merged = False
+
+ def reset_parameters(self):
+ #self.conv.reset_parameters()
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ # def train(self, mode=True):
+ # super(ConvLoRA, self).train(mode)
+ # if mode:
+ # if self.merge_weights and self.merged:
+ # if self.r > 0:
+ # # Make sure that the weights are not merged
+ # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
+ # self.merged = False
+ # else:
+ # if self.merge_weights and not self.merged:
+ # if self.r > 0:
+ # # Merge the weights and mark it
+ # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
+ # self.merged = True
+
+ def forward(self, x):
+ if self.r > 0 and not self.merged:
+ # return self.conv._conv_forward(
+ # x,
+ # self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,
+ # self.conv.bias
+ # )
+ weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+ bias = self.bias
+
+ return F.conv2d(x, weight, bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
+ else:
+ return F.conv2d(x, self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
+
+class ConvTransposeLoRA(nn.ConvTranspose2d, LoRALayer):
+ def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):
+ #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)
+ nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
+ LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
+ assert isinstance(kernel_size, int)
+
+ # Actual trainable parameters
+ if r > 0:
+ self.lora_A = nn.Parameter(
+ self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
+ )
+ self.lora_B = nn.Parameter(
+ self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size))
+ )
+ self.scaling = self.lora_alpha / self.r
+ # Freezing the pre-trained weight matrix
+ self.weight.requires_grad = False
+ self.reset_parameters()
+ self.merged = False
+
+ def reset_parameters(self):
+ #self.conv.reset_parameters()
+ if hasattr(self, 'lora_A'):
+ # initialize A the same way as the default for nn.Linear and B to zero
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+ nn.init.zeros_(self.lora_B)
+
+ # def train(self, mode=True):
+ # super(ConvTransposeLoRA, self).train(mode)
+ # if mode:
+ # if self.merge_weights and self.merged:
+ # if self.r > 0:
+ # # Make sure that the weights are not merged
+ # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
+ # self.merged = False
+ # else:
+ # if self.merge_weights and not self.merged:
+ # if self.r > 0:
+ # # Merge the weights and mark it
+ # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling
+ # self.merged = True
+
+ def forward(self, x):
+ if self.r > 0 and not self.merged:
+ weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
+ bias = self.bias
+ return F.conv_transpose2d(x, weight,
+ bias=bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding,
+ groups=self.groups, dilation=self.dilation)
+ else:
+ return F.conv_transpose2d(x, self.weight,
+ bias=self.bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding,
+ groups=self.groups, dilation=self.dilation)
+ #return self.conv(x)
+
+class Conv2dLoRA(ConvLoRA):
+ def __init__(self, *args, **kwargs):
+ super(Conv2dLoRA, self).__init__(*args, **kwargs)
+
+class ConvTranspose2dLoRA(ConvTransposeLoRA):
+ def __init__(self, *args, **kwargs):
+ super(ConvTranspose2dLoRA, self).__init__(*args, **kwargs)
+
+
+def compute_depth_expectation(prob, depth_values):
+ depth_values = depth_values.view(*depth_values.shape, 1, 1)
+ depth = torch.sum(prob * depth_values, 1)
+ return depth
+
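+# Interpolate in float32 with autocast disabled so bilinear resizing is not run in bf16.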
+def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
+ return F.interpolate(x.float(), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
+
+# def upflow8(flow, mode='bilinear'):
+# new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+# return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+
+def upflow4(flow, mode='bilinear'):
+ new_size = (4 * flow.shape[2], 4 * flow.shape[3])
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
+ return F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+
+def coords_grid(batch, ht, wd):
+ # coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+ coords = (torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)))
+ coords = torch.stack(coords[::-1], dim=0).float()
+ return coords[None].repeat(batch, 1, 1, 1)
+
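+# Normalize the predicted (nx, ny, nz) to unit length and map kappa through ELU + 1 + min_kappa
+# so the concentration parameter stays strictly positive.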
+def norm_normalize(norm_out):
+ min_kappa = 0.01
+ norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
+ norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
+ kappa = F.elu(kappa) + 1.0 + min_kappa
+ final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
+ return final_out
+
+# uncertainty-guided sampling (only used during training)
+@torch.no_grad()
+def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
+ device = init_normal.device
+ B, _, H, W = init_normal.shape
+ N = int(sampling_ratio * H * W)
+ beta = beta
+
+ # uncertainty map
+ uncertainty_map = -1 * init_normal[:, -1, :, :] # B, H, W
+
+ # gt_invalid_mask (B, H, W)
+ if gt_norm_mask is not None:
+ gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
+ gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
+ uncertainty_map[gt_invalid_mask] = -1e4
+
+ # (B, H*W)
+ _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
+
+ # importance sampling
+ if int(beta * N) > 0:
+ importance = idx[:, :int(beta * N)] # B, beta*N
+
+ # remaining
+ remaining = idx[:, int(beta * N):] # B, H*W - beta*N
+
+ # coverage
+ num_coverage = N - int(beta * N)
+
+ if num_coverage <= 0:
+ samples = importance
+ else:
+ coverage_list = []
+ for i in range(B):
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
+ samples = torch.cat((importance, coverage), dim=1) # B, N
+
+ else:
+ # remaining
+ remaining = idx[:, :] # B, H*W
+
+ # coverage
+ num_coverage = N
+
+ coverage_list = []
+ for i in range(B):
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
+ samples = coverage
+
+ # point coordinates
+ rows_int = samples // W # 0 for first row, H-1 for last row
+ rows_float = rows_int / float(H-1) # 0 to 1.0
+ rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
+
+ cols_int = samples % W # 0 for first column, W-1 for last column
+ cols_float = cols_int / float(W-1) # 0 to 1.0
+ cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
+
+ point_coords = torch.zeros(B, 1, N, 2)
+ point_coords[:, 0, :, 0] = cols_float # x coord
+ point_coords[:, 0, :, 1] = rows_float # y coord
+ point_coords = point_coords.to(device)
+ return point_coords, rows_int, cols_int
+
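+# FlowHead: two small conv branches predicting the depth update and the normal update from the
+# GRU hidden state; their outputs are concatenated along the channel dimension.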
+class FlowHead(nn.Module):
+ def __init__(self, input_dim=128, hidden_dim=256, output_dim_depth=2, output_dim_norm=4, tuning_mode=None):
+ super(FlowHead, self).__init__()
+ self.conv1d = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
+ self.conv2d = Conv2dLoRA(hidden_dim // 2, output_dim_depth, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
+
+ self.conv1n = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
+ self.conv2n = Conv2dLoRA(hidden_dim // 2, output_dim_norm, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ depth = self.conv2d(self.relu(self.conv1d(x)))
+ normal = self.conv2n(self.relu(self.conv1n(x)))
+ return torch.cat((depth, normal), dim=1)
+
+
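+# Convolutional GRU cell used by the iterative update block; cz / cr / cq are context features
+# added as biases to the update, reset, and candidate gates.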
+class ConvGRU(nn.Module):
+ def __init__(self, hidden_dim, input_dim, kernel_size=3, tuning_mode=None):
+ super(ConvGRU, self).__init__()
+ self.convz = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
+ self.convr = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
+ self.convq = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0)
+
+ def forward(self, h, cz, cr, cq, *x_list):
+ x = torch.cat(x_list, dim=1)
+ hx = torch.cat([h, x], dim=1)
+
+ z = torch.sigmoid((self.convz(hx) + cz))
+ r = torch.sigmoid((self.convr(hx) + cr))
+ q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq))
+
+ # z = torch.sigmoid((self.convz(hx) + cz).float())
+ # r = torch.sigmoid((self.convr(hx) + cr).float())
+ # q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq).float())
+
+ h = (1-z) * h + z * q
+ return h
+
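+# Multi-scale helpers: 2x / 4x average pooling and bilinear resizing to a destination tensor's shape.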
+def pool2x(x):
+ return F.avg_pool2d(x, 3, stride=2, padding=1)
+
+def pool4x(x):
+ return F.avg_pool2d(x, 5, stride=4, padding=1)
+
+def interp(x, dest):
+ interp_args = {'mode': 'bilinear', 'align_corners': True}
+ return interpolate_float32(x, dest.shape[2:], **interp_args)
+
+class BasicMultiUpdateBlock(nn.Module):
+ def __init__(self, args, hidden_dims=[], out_dims=2, tuning_mode=None):
+ super().__init__()
+ self.args = args
+ self.n_gru_layers = args.model.decode_head.n_gru_layers # 3
+ self.n_downsample = args.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K)
+
+ # self.encoder = BasicMotionEncoder(args)
+ # encoder_output_dim = 128 # if there is corr volume
+ encoder_output_dim = 6 # no corr volume
+
+ self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1), tuning_mode=tuning_mode)
+ self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2], tuning_mode=tuning_mode)
+ self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1], tuning_mode=tuning_mode)
+ self.flow_head = FlowHead(hidden_dims[2], hidden_dim=2*hidden_dims[2], tuning_mode=tuning_mode)
+ factor = 2**self.n_downsample
+
+ self.mask = nn.Sequential(
+ Conv2dLoRA(hidden_dims[2], hidden_dims[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0),
+ nn.ReLU(inplace=True),
+ Conv2dLoRA(hidden_dims[2], (factor**2)*9, 1, padding=0, r = 8 if tuning_mode == 'lora' else 0))
+
+ def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True):
+
+ if iter32:
+ net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1]))
+ if iter16:
+ if self.n_gru_layers > 2:
+ net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]), interp(net[2], net[1]))
+ else:
+ net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]))
+ if iter08:
+ if corr is not None:
+ motion_features = self.encoder(flow, corr)
+ else:
+ motion_features = flow
+ if self.n_gru_layers > 1:
+ net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0]))
+ else:
+ net[0] = self.gru08(net[0], *(inp[0]), motion_features)
+
+ if not update:
+ return net
+
+ delta_flow = self.flow_head(net[0])
+
+        # scale mask to balance gradients
+ mask = .25 * self.mask(net[0])
+ return net, mask, delta_flow
+
+class LayerNorm2d(nn.LayerNorm):
+ def __init__(self, dim):
+ super(LayerNorm2d, self).__init__(dim)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 3, 1).contiguous()
+ x = super(LayerNorm2d, self).forward(x)
+ x = x.permute(0, 3, 1, 2).contiguous()
+ return x
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1, tuning_mode=None):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = Conv2dLoRA(in_planes, planes, kernel_size=3, padding=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0)
+ self.conv2 = Conv2dLoRA(planes, planes, kernel_size=3, padding=1, r = 8 if tuning_mode == 'lora' else 0)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'layer':
+ self.norm1 = LayerNorm2d(planes)
+ self.norm2 = LayerNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = LayerNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.Sequential()
+
+ if stride == 1 and in_planes == planes:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ Conv2dLoRA(in_planes, planes, kernel_size=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0), self.norm3)
+
+ def forward(self, x):
+ y = x
+ y = self.conv1(y)
+ y = self.norm1(y)
+ y = self.relu(y)
+ y = self.conv2(y)
+ y = self.norm2(y)
+ y = self.relu(y)
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+
+class ContextFeatureEncoder(nn.Module):
+ '''
+ Encoder features are used to:
+ 1. initialize the hidden state of the update operator
+    2. provide the context features injected into the GRU at each iteration of the update operator
+ '''
+ def __init__(self, in_dim, output_dim, tuning_mode=None):
+ '''
+ in_dim = [x4, x8, x16, x32]
+        output_dim = [hidden_dims, context_dims]
+ [[x4,x8,x16,x32],[x4,x8,x16,x32]]
+ '''
+ super().__init__()
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(in_dim[0], dim[0], 'layer', stride=1, tuning_mode=tuning_mode),
+ Conv2dLoRA(dim[0], dim[0], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
+ output_list.append(conv_out)
+
+ self.outputs04 = nn.ModuleList(output_list)
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(in_dim[1], dim[1], 'layer', stride=1, tuning_mode=tuning_mode),
+ Conv2dLoRA(dim[1], dim[1], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
+ output_list.append(conv_out)
+
+ self.outputs08 = nn.ModuleList(output_list)
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(in_dim[2], dim[2], 'layer', stride=1, tuning_mode=tuning_mode),
+ Conv2dLoRA(dim[2], dim[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0))
+ output_list.append(conv_out)
+
+ self.outputs16 = nn.ModuleList(output_list)
+
+ # output_list = []
+ # for dim in output_dim:
+ # conv_out = Conv2dLoRA(in_dim[3], dim[3], 3, padding=1)
+ # output_list.append(conv_out)
+
+ # self.outputs32 = nn.ModuleList(output_list)
+
+ def forward(self, encoder_features):
+ x_4, x_8, x_16, x_32 = encoder_features
+
+ outputs04 = [f(x_4) for f in self.outputs04]
+ outputs08 = [f(x_8) for f in self.outputs08]
+ outputs16 = [f(x_16)for f in self.outputs16]
+ # outputs32 = [f(x_32) for f in self.outputs32]
+
+ return (outputs04, outputs08, outputs16)
+
+class ConvBlock(nn.Module):
+ # reimplementation of DPT
+ def __init__(self, channels, tuning_mode=None):
+ super(ConvBlock, self).__init__()
+
+ self.act = nn.ReLU(inplace=True)
+ self.conv1 = Conv2dLoRA(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ r = 8 if tuning_mode == 'lora' else 0
+ )
+ self.conv2 = Conv2dLoRA(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ r = 8 if tuning_mode == 'lora' else 0
+ )
+
+ def forward(self, x):
+ out = self.act(x)
+ out = self.conv1(out)
+ out = self.act(out)
+ out = self.conv2(out)
+ return x + out
+
+class FuseBlock(nn.Module):
+ # reimplementation of DPT
+ def __init__(self, in_channels, out_channels, fuse=True, upsample=True, scale_factor=2, tuning_mode=None):
+ super(FuseBlock, self).__init__()
+
+ self.fuse = fuse
+ self.scale_factor = scale_factor
+ self.way_trunk = ConvBlock(in_channels, tuning_mode=tuning_mode)
+ if self.fuse:
+ self.way_branch = ConvBlock(in_channels, tuning_mode=tuning_mode)
+
+ self.out_conv = Conv2dLoRA(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ r = 8 if tuning_mode == 'lora' else 0
+ )
+ self.upsample = upsample
+
+ def forward(self, x1, x2=None):
+ if x2 is not None:
+ x2 = self.way_branch(x2)
+ x1 = x1 + x2
+
+ out = self.way_trunk(x1)
+
+ if self.upsample:
+ out = interpolate_float32(
+ out, scale_factor=self.scale_factor, mode="bilinear", align_corners=True
+ )
+ out = self.out_conv(out)
+ return out
+
+class Readout(nn.Module):
+ # From DPT
+ def __init__(self, in_features, use_cls_token=True, num_register_tokens=0, tuning_mode=None):
+ super(Readout, self).__init__()
+ self.use_cls_token = use_cls_token
+ if self.use_cls_token == True:
+ self.project_patch = LoRALinear(in_features, in_features, r = 8 if tuning_mode == 'lora' else 0)
+ self.project_learn = LoRALinear((1 + num_register_tokens) * in_features, in_features, bias=False, r = 8 if tuning_mode == 'lora' else 0)
+ self.act = nn.GELU()
+ else:
+ self.project = nn.Identity()
+
+ def forward(self, x):
+
+ if self.use_cls_token == True:
+ x_patch = self.project_patch(x[0])
+ x_learn = self.project_learn(x[1])
+ x_learn = x_learn.expand_as(x_patch).contiguous()
+ features = x_patch + x_learn
+ return self.act(features)
+ else:
+ return self.project(x)
+
+class Token2Feature(nn.Module):
+ # From DPT
+ def __init__(self, vit_channel, feature_channel, scale_factor, use_cls_token=True, num_register_tokens=0, tuning_mode=None):
+ super(Token2Feature, self).__init__()
+ self.scale_factor = scale_factor
+ self.readoper = Readout(in_features=vit_channel, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
+ if scale_factor > 1 and isinstance(scale_factor, int):
+ self.sample = ConvTranspose2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
+ in_channels=vit_channel,
+ out_channels=feature_channel,
+ kernel_size=scale_factor,
+ stride=scale_factor,
+ padding=0,
+ )
+
+ elif scale_factor > 1:
+ self.sample = nn.Sequential(
+ # Upsample2(upscale=scale_factor),
+ # nn.Upsample(scale_factor=scale_factor),
+ Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
+ in_channels=vit_channel,
+ out_channels=feature_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+
+ elif scale_factor < 1:
+ scale_factor = int(1.0 / scale_factor)
+ self.sample = Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0,
+ in_channels=vit_channel,
+ out_channels=feature_channel,
+ kernel_size=scale_factor+1,
+ stride=scale_factor,
+ padding=1,
+ )
+
+ else:
+ self.sample = nn.Identity()
+
+ def forward(self, x):
+ x = self.readoper(x)
+ #if use_cls_token == True:
+ x = x.permute(0, 3, 1, 2).contiguous()
+ if isinstance(self.scale_factor, float):
+ x = interpolate_float32(x.float(), scale_factor=self.scale_factor, mode='nearest')
+ x = self.sample(x)
+ return x
+
+class EncoderFeature(nn.Module):
+ def __init__(self, vit_channel, num_ch_dec=[256, 512, 1024, 1024], use_cls_token=True, num_register_tokens=0, tuning_mode=None):
+ super(EncoderFeature, self).__init__()
+ self.vit_channel = vit_channel
+ self.num_ch_dec = num_ch_dec
+
+ self.read_3 = Token2Feature(self.vit_channel, self.num_ch_dec[3], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
+ self.read_2 = Token2Feature(self.vit_channel, self.num_ch_dec[2], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
+ self.read_1 = Token2Feature(self.vit_channel, self.num_ch_dec[1], scale_factor=2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
+ self.read_0 = Token2Feature(self.vit_channel, self.num_ch_dec[0], scale_factor=7/2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode)
+
+ def forward(self, ref_feature):
+ x = self.read_3(ref_feature[3]) # 1/14
+ x2 = self.read_2(ref_feature[2]) # 1/14
+ x1 = self.read_1(ref_feature[1]) # 1/7
+ x0 = self.read_0(ref_feature[0]) # 1/4
+
+ return x, x2, x1, x0
+
+class DecoderFeature(nn.Module):
+ def __init__(self, vit_channel, num_ch_dec=[128, 256, 512, 1024, 1024], use_cls_token=True, tuning_mode=None):
+ super(DecoderFeature, self).__init__()
+ self.vit_channel = vit_channel
+ self.num_ch_dec = num_ch_dec
+
+ self.upconv_3 = FuseBlock(
+ self.num_ch_dec[4],
+ self.num_ch_dec[3],
+ fuse=False, upsample=False, tuning_mode=tuning_mode)
+
+ self.upconv_2 = FuseBlock(
+ self.num_ch_dec[3],
+ self.num_ch_dec[2],
+ tuning_mode=tuning_mode)
+
+ self.upconv_1 = FuseBlock(
+ self.num_ch_dec[2],
+ self.num_ch_dec[1] + 2,
+ scale_factor=7/4,
+ tuning_mode=tuning_mode)
+
+ # self.upconv_0 = FuseBlock(
+ # self.num_ch_dec[1],
+ # self.num_ch_dec[0] + 1,
+ # )
+
+ def forward(self, ref_feature):
+ x, x2, x1, x0 = ref_feature # 1/14 1/14 1/7 1/4
+
+ x = self.upconv_3(x) # 1/14
+ x = self.upconv_2(x, x2) # 1/7
+ x = self.upconv_1(x, x1) # 1/4
+ # x = self.upconv_0(x, x0) # 4/7
+ return x
+
+class RAFTDepthNormalDPT5(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.in_channels = cfg.model.decode_head.in_channels # [1024, 1024, 1024, 1024]
+ self.feature_channels = cfg.model.decode_head.feature_channels # [256, 512, 1024, 1024] [2/7, 1/7, 1/14, 1/14]
+ self.decoder_channels = cfg.model.decode_head.decoder_channels # [128, 256, 512, 1024, 1024] [-, 1/4, 1/7, 1/14, 1/14]
+ self.use_cls_token = cfg.model.decode_head.use_cls_token
+ self.up_scale = cfg.model.decode_head.up_scale
+ self.num_register_tokens = cfg.model.decode_head.num_register_tokens
+ self.min_val = cfg.data_basic.depth_normalize[0]
+ self.max_val = cfg.data_basic.depth_normalize[1]
+        self.regress_scale = 100.0
+
+ try:
+ tuning_mode = cfg.model.decode_head.tuning_mode
+ except:
+ tuning_mode = None
+ self.tuning_mode = tuning_mode
+
+ self.hidden_dims = self.context_dims = cfg.model.decode_head.hidden_channels # [128, 128, 128, 128]
+ self.n_gru_layers = cfg.model.decode_head.n_gru_layers # 3
+ self.n_downsample = cfg.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K)
+ self.iters = cfg.model.decode_head.iters # 22
+ self.slow_fast_gru = cfg.model.decode_head.slow_fast_gru # True
+
+ self.num_depth_regressor_anchor = 256 # 512
+ self.used_res_channel = self.decoder_channels[1] # now, use 2/7 res
+ self.token2feature = EncoderFeature(self.in_channels[0], self.feature_channels, self.use_cls_token, self.num_register_tokens, tuning_mode=tuning_mode)
+ self.decoder_mono = DecoderFeature(self.in_channels, self.decoder_channels, tuning_mode=tuning_mode)
+ self.depth_regressor = nn.Sequential(
+ Conv2dLoRA(self.used_res_channel,
+ self.num_depth_regressor_anchor,
+ kernel_size=3,
+ padding=1, r = 8 if tuning_mode == 'lora' else 0),
+ # nn.BatchNorm2d(self.num_depth_regressor_anchor),
+ nn.ReLU(inplace=True),
+ Conv2dLoRA(self.num_depth_regressor_anchor,
+ self.num_depth_regressor_anchor,
+ kernel_size=1, r = 8 if tuning_mode == 'lora' else 0),
+ )
+ self.normal_predictor = nn.Sequential(
+ Conv2dLoRA(self.used_res_channel,
+ 128,
+ kernel_size=3,
+ padding=1, r = 8 if tuning_mode == 'lora' else 0,),
+ # nn.BatchNorm2d(128),
+ nn.ReLU(inplace=True),
+ Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True),
+ Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True),
+ Conv2dLoRA(128, 3, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0),
+ )
+
+ self.context_feature_encoder = ContextFeatureEncoder(self.feature_channels, [self.hidden_dims, self.context_dims], tuning_mode=tuning_mode)
+ self.context_zqr_convs = nn.ModuleList([Conv2dLoRA(self.context_dims[i], self.hidden_dims[i]*3, 3, padding=3//2, r = 8 if tuning_mode == 'lora' else 0) for i in range(self.n_gru_layers)])
+ self.update_block = BasicMultiUpdateBlock(cfg, hidden_dims=self.hidden_dims, out_dims=6, tuning_mode=tuning_mode)
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def get_bins(self, bins_num):
+ depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda")
+ depth_bins_vec = torch.exp(depth_bins_vec)
+ return depth_bins_vec
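+    # get_bins returns anchors spaced uniformly in log-depth between min_val and max_val,
+    # e.g. 256 log-spaced values for a hypothetical depth_normalize=(0.3, 150) range.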
+
+ def register_depth_expectation_anchor(self, bins_num, B):
+ depth_bins_vec = self.get_bins(bins_num)
+ depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
+ self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)
+
+ def clamp(self, x):
+ y = self.relu(x - self.min_val) + self.min_val
+ y = self.max_val - self.relu(self.max_val - y)
+ return y
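+    # Two ReLUs implement a clamp of x into [min_val, max_val],
+    # equivalent to torch.clamp(x, self.min_val, self.max_val).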
+
+ def regress_depth(self, feature_map_d):
+ prob_feature = self.depth_regressor(feature_map_d)
+ prob = prob_feature.softmax(dim=1)
+ #prob = prob_feature.float().softmax(dim=1)
+
+ ## Error logging
+ if torch.isnan(prob).any():
+ print('prob_feat_nan!!!')
+ if torch.isinf(prob).any():
+ print('prob_feat_inf!!!')
+
+ # h = prob[0,:,0,0].cpu().numpy().reshape(-1)
+ # import matplotlib.pyplot as plt
+ # plt.bar(range(len(h)), h)
+ B = prob.shape[0]
+ if "depth_expectation_anchor" not in self._buffers:
+ self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B)
+ d = compute_depth_expectation(
+ prob,
+ self.depth_expectation_anchor[:B, ...]).unsqueeze(1)
+
+ ## Error logging
+        if torch.isnan(d).any():
+            print('d_nan!!!')
+        if torch.isinf(d).any():
+            print('d_inf!!!')
+
+        return (self.clamp(d) - self.max_val) / self.regress_scale, prob_feature
+
+ def pred_normal(self, feature_map, confidence):
+ normal_out = self.normal_predictor(feature_map)
+
+ ## Error logging
+ if torch.isnan(normal_out).any():
+ print('norm_nan!!!')
+ if torch.isinf(normal_out).any():
+ print('norm_feat_inf!!!')
+
+ return norm_normalize(torch.cat([normal_out, confidence], dim=1))
+ #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float())
+
+ def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
+ y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
+ torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
+ meshgrid = torch.stack((x, y))
+ meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
+ #self.register_buffer('meshgrid', meshgrid, persistent=False)
+ return meshgrid
+
+ def upsample_flow(self, flow, mask):
+        """ Upsample a [N, D, H/f, W/f] field to [N, D, H, W] (f = 2**n_downsample) using convex combination """
+ N, D, H, W = flow.shape
+ factor = 2 ** self.n_downsample
+ mask = mask.view(N, 1, 9, factor, factor, H, W)
+ mask = torch.softmax(mask, dim=2)
+ #mask = torch.softmax(mask.float(), dim=2)
+
+ #up_flow = F.unfold(factor * flow, [3,3], padding=1)
+ up_flow = F.unfold(flow, [3,3], padding=1)
+ up_flow = up_flow.view(N, D, 9, 1, 1, H, W)
+
+ up_flow = torch.sum(mask * up_flow, dim=2)
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+ return up_flow.reshape(N, D, factor*H, factor*W)
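+    # Convex upsampling (as in RAFT): each fine-resolution value is a softmax-weighted
+    # combination of its 3x3 coarse neighborhood, with weights predicted by self.mask.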
+
+ def initialize_flow(self, img):
+ """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
+ N, _, H, W = img.shape
+
+ coords0 = coords_grid(N, H, W).to(img.device)
+ coords1 = coords_grid(N, H, W).to(img.device)
+
+ return coords0, coords1
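+    # NOTE: despite the optical-flow naming inherited from RAFT, the iterated state here
+    # appears to carry depth, confidence and normal channels (see depth_init in forward).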
+
+ def upsample(self, x, scale_factor=2):
+        """Upsample the input tensor by scale_factor * up_scale / 8 (nearest neighbour)
+        """
+ return interpolate_float32(x, scale_factor=scale_factor*self.up_scale/8, mode="nearest")
+
+ def forward(self, vit_features, **kwargs):
+ ## read vit token to multi-scale features
+ B, H, W, _, _, num_register_tokens = vit_features[1]
+ vit_features = vit_features[0]
+
+ ## Error logging
+ if torch.isnan(vit_features[0]).any():
+ print('vit_feature_nan!!!')
+ if torch.isinf(vit_features[0]).any():
+ print('vit_feature_inf!!!')
+
+ if self.use_cls_token == True:
+ vit_features = [[ft[:, 1+num_register_tokens:, :].view(B, H, W, self.in_channels[0]), \
+ ft[:, 0:1+num_register_tokens, :].view(B, 1, 1, self.in_channels[0] * (1+num_register_tokens))] for ft in vit_features]
+ else:
+ vit_features = [ft.view(B, H, W, self.in_channels[0]) for ft in vit_features]
+ encoder_features = self.token2feature(vit_features) # 1/14, 1/14, 1/7, 1/4
+
+ ## Error logging
+ for en_ft in encoder_features:
+ if torch.isnan(en_ft).any():
+                print('encoder_feature_nan!!!')
+                print(en_ft.shape)
+            if torch.isinf(en_ft).any():
+                print('encoder_feature_inf!!!')
+ print(en_ft.shape)
+
+ ## decode features to init-depth (and confidence)
+ ref_feat= self.decoder_mono(encoder_features) # now, 1/4 for depth
+
+ ## Error logging
+ if torch.isnan(ref_feat).any():
+ print('ref_feat_nan!!!')
+ if torch.isinf(ref_feat).any():
+ print('ref_feat_inf!!!')
+
+        feature_map = ref_feat[:, :-2, :, :] # feature map shared by depth and normal prediction
+ depth_confidence_map = ref_feat[:, -2:-1, :, :]
+ normal_confidence_map = ref_feat[:, -1:, :, :]
+ depth_pred, binmap = self.regress_depth(feature_map) # regress bin for depth
+ normal_pred = self.pred_normal(feature_map, normal_confidence_map) # mlp for normal
+
+ depth_init = torch.cat((depth_pred, depth_confidence_map, normal_pred), dim=1) # (N, 1+1+4, H, W)
+
+        ## encode features into the initial hidden state and per-iteration context features
+ cnet_list = self.context_feature_encoder(encoder_features[::-1])
+ net_list = [torch.tanh(x[0]) for x in cnet_list] # x_4, x_8, x_16 of hidden state
+ inp_list = [torch.relu(x[1]) for x in cnet_list] # x_4, x_8, x_16 context features
+
+ # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning
+ inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)]
+
+ coords0, coords1 = self.initialize_flow(net_list[0])
+ if depth_init is not None:
+ coords1 = coords1 + depth_init
+
+ if self.training:
+ low_resolution_init = [self.clamp(depth_init[:,:1] * self.regress_scale + self.max_val), depth_init[:,1:2], norm_normalize(depth_init[:,2:].clone())]
+ init_depth = upflow4(depth_init)
+ flow_predictions = [self.clamp(init_depth[:,:1] * self.regress_scale + self.max_val)]
+ conf_predictions = [init_depth[:,1:2]]
+ normal_outs = [norm_normalize(init_depth[:,2:].clone())]
+
+ else:
+ flow_predictions = []
+ conf_predictions = []
+ samples_pred_list = []
+ coord_list = []
+ normal_outs = []
+ low_resolution_init = []
+
+ for itr in range(self.iters):
+ # coords1 = coords1.detach()
+ flow = coords1 - coords0
+ if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU
+ net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
+ if self.n_gru_layers >= 2 and self.slow_fast_gru:# Update low-res GRU and mid-res GRU
+ net_list = self.update_block(net_list, inp_list, iter32=self.n_gru_layers==3, iter16=True, iter08=False, update=False)
+ net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, None, flow, iter32=self.n_gru_layers==3, iter16=self.n_gru_layers>=2)
+
+ # F(t+1) = F(t) + \Delta(t)
+ coords1 = coords1 + delta_flow
+
+ # We do not need to upsample or output intermediate results in test_mode
+ #if (not self.training) and itr < self.iters-1:
+ #continue
+
+ # upsample predictions
+ if up_mask is None:
+ flow_up = self.upsample(coords1-coords0, 4)
+ else:
+ flow_up = self.upsample_flow(coords1 - coords0, up_mask)
+ # flow_up = self.upsample(coords1-coords0, 4)
+
+ flow_predictions.append(self.clamp(flow_up[:,:1] * self.regress_scale + self.max_val))
+ conf_predictions.append(flow_up[:,1:2])
+ normal_outs.append(norm_normalize(flow_up[:,2:].clone()))
+
+ outputs=dict(
+ prediction=flow_predictions[-1],
+ predictions_list=flow_predictions,
+ confidence=conf_predictions[-1],
+ confidence_list=conf_predictions,
+ pred_logit=None,
+ # samples_pred_list=samples_pred_list,
+ # coord_list=coord_list,
+ prediction_normal=normal_outs[-1],
+ normal_out_list=normal_outs,
+ low_resolution_init=low_resolution_init,
+ )
+
+ return outputs
+
+
+if __name__ == "__main__":
+ try:
+ from mmcv.utils import Config
+ except:
+ from mmengine import Config
+ cfg = Config.fromfile('/cpfs01/shared/public/users/mu.hu/monodepth/mono/configs/RAFTDecoder/vit.raft.full2t.py')
+ cfg.model.decode_head.in_channels = [384, 384, 384, 384]
+ cfg.model.decode_head.feature_channels = [96, 192, 384, 768]
+ cfg.model.decode_head.decoder_channels = [48, 96, 192, 384, 384]
+ cfg.model.decode_head.hidden_channels = [48, 48, 48, 48, 48]
+ cfg.model.decode_head.up_scale = 7
+
+ # cfg.model.decode_head.use_cls_token = True
+ # vit_feature = [[torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()]]
+
+ cfg.model.decode_head.use_cls_token = True
+ cfg.model.decode_head.num_register_tokens = 4
+ vit_feature = [[torch.rand((2, (74 * 74) + 5, 384)).cuda(),\
+ torch.rand((2, (74 * 74) + 5, 384)).cuda(), \
+ torch.rand((2, (74 * 74) + 5, 384)).cuda(), \
+ torch.rand((2, (74 * 74) + 5, 384)).cuda()], (2, 74, 74, 1036, 1036, 4)]
+
+ decoder = RAFTDepthNormalDPT5(cfg).cuda()
+ output = decoder(vit_feature)
+ temp = 1
+
+
+
+
diff --git a/mono/model/decode_heads/__init__.py b/mono/model/decode_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92381a5fc3dad0ca8009c1ab0a153ce6b107c634
--- /dev/null
+++ b/mono/model/decode_heads/__init__.py
@@ -0,0 +1,4 @@
+from .HourGlassDecoder import HourglassDecoder
+from .RAFTDepthNormalDPTDecoder5 import RAFTDepthNormalDPT5
+
+__all__=['HourglassDecoder', 'RAFTDepthNormalDPT5']
diff --git a/mono/model/model_pipelines/__base_model__.py b/mono/model/model_pipelines/__base_model__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d599c418b3d9677a195fe87d45bb31bf1068fbce
--- /dev/null
+++ b/mono/model/model_pipelines/__base_model__.py
@@ -0,0 +1,20 @@
+import torch
+import torch.nn as nn
+from mono.utils.comm import get_func
+
+
+class BaseDepthModel(nn.Module):
+ def __init__(self, cfg, **kwargs) -> None:
+ super(BaseDepthModel, self).__init__()
+ model_type = cfg.model.type
+ self.depth_model = get_func('mono.model.model_pipelines.' + model_type)(cfg)
+
+ def forward(self, data):
+ output = self.depth_model(**data)
+
+ return output['prediction'], output['confidence'], output
+
+ def inference(self, data):
+ with torch.no_grad():
+ pred_depth, confidence, _ = self.forward(data)
+ return pred_depth, confidence
\ No newline at end of file
diff --git a/mono/model/model_pipelines/__init__.py b/mono/model/model_pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b962a3f858573466e429219c4ad70951b545b637
--- /dev/null
+++ b/mono/model/model_pipelines/__init__.py
@@ -0,0 +1,6 @@
+
+from .dense_pipeline import DensePredModel
+from .__base_model__ import BaseDepthModel
+__all__ = [
+ 'DensePredModel', 'BaseDepthModel',
+]
\ No newline at end of file
diff --git a/mono/model/model_pipelines/dense_pipeline.py b/mono/model/model_pipelines/dense_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..1362a11b6b9d45e50795dd705906aa3f79ec4a9a
--- /dev/null
+++ b/mono/model/model_pipelines/dense_pipeline.py
@@ -0,0 +1,16 @@
+import torch
+import torch.nn as nn
+from mono.utils.comm import get_func
+
+class DensePredModel(nn.Module):
+ def __init__(self, cfg) -> None:
+ super(DensePredModel, self).__init__()
+
+ self.encoder = get_func('mono.model.' + cfg.model.backbone.prefix + cfg.model.backbone.type)(**cfg.model.backbone)
+ self.decoder = get_func('mono.model.' + cfg.model.decode_head.prefix + cfg.model.decode_head.type)(cfg)
+
+ def forward(self, input, **kwargs):
+ # [f_32, f_16, f_8, f_4]
+ features = self.encoder(input)
+ out = self.decoder(features, **kwargs)
+ return out
\ No newline at end of file
diff --git a/mono/model/monodepth_model.py b/mono/model/monodepth_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b58b7643ee43f84fd4e621e5b3b61b1f3f85564
--- /dev/null
+++ b/mono/model/monodepth_model.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+from .model_pipelines.__base_model__ import BaseDepthModel
+
+class DepthModel(BaseDepthModel):
+    def __init__(self, cfg, **kwargs):
+ super(DepthModel, self).__init__(cfg)
+ model_type = cfg.model.type
+
+ def inference(self, data):
+ with torch.no_grad():
+ pred_depth, confidence, output_dict = self.forward(data)
+ return pred_depth, confidence, output_dict
+
+def get_monodepth_model(
+ cfg : dict,
+ **kwargs
+ ) -> nn.Module:
+ # config depth model
+ model = DepthModel(cfg, **kwargs)
+ #model.init_weights(load_imagenet_model, imagenet_ckpt_fpath)
+ assert isinstance(model, nn.Module)
+ return model
+
+def get_configured_monodepth_model(
+ cfg: dict,
+ ) -> nn.Module:
+    """
+    Args:
+        cfg: configuration for the network.
+    Returns:
+        model: depth model.
+    """
+ model = get_monodepth_model(cfg)
+ return model
diff --git a/mono/tools/test_scale_cano.py b/mono/tools/test_scale_cano.py
new file mode 100644
index 0000000000000000000000000000000000000000..684fb841a004833e27edd52192ad0821bf2d43af
--- /dev/null
+++ b/mono/tools/test_scale_cano.py
@@ -0,0 +1,158 @@
+import os
+import os.path as osp
+import cv2
+import time
+import sys
+CODE_SPACE=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(CODE_SPACE)
+import argparse
+import mmcv
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+try:
+ from mmcv.utils import Config, DictAction
+except:
+ from mmengine import Config, DictAction
+from datetime import timedelta
+import random
+import numpy as np
+from mono.utils.logger import setup_logger
+import glob
+from mono.utils.comm import init_env
+from mono.model.monodepth_model import get_configured_monodepth_model
+from mono.utils.running import load_ckpt
+from mono.utils.do_test import do_scalecano_test_with_custom_data
+from mono.utils.mldb import load_data_info, reset_ckpt_path
+from mono.utils.custom_data import load_from_annos, load_data
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Test a monocular depth model on custom data')
+    parser.add_argument('config', help='test config file path')
+ parser.add_argument('--show-dir', help='the dir to save logs and visualization results')
+ parser.add_argument('--load-from', help='the checkpoint file to load weights from')
+ parser.add_argument('--node_rank', type=int, default=0)
+ parser.add_argument('--nnodes', type=int, default=1, help='number of nodes')
+ parser.add_argument('--options', nargs='+', action=DictAction, help='custom options')
+ parser.add_argument('--launcher', choices=['None', 'pytorch', 'slurm', 'mpi', 'ror'], default='slurm', help='job launcher')
+ parser.add_argument('--test_data_path', default='None', type=str, help='the path of test data')
+ args = parser.parse_args()
+ return args
+
+def main(args):
+ os.chdir(CODE_SPACE)
+ cfg = Config.fromfile(args.config)
+
+ if args.options is not None:
+ cfg.merge_from_dict(args.options)
+
+    # show_dir is determined in this priority: CLI > config file > filename
+ if args.show_dir is not None:
+ # update configs according to CLI args if args.show_dir is not None
+ cfg.show_dir = args.show_dir
+ else:
+        # use config filename + timestamp as default show_dir if args.show_dir is None
+ cfg.show_dir = osp.join('./show_dirs',
+ osp.splitext(osp.basename(args.config))[0],
+ args.timestamp)
+
+ # ckpt path
+ if args.load_from is None:
+ raise RuntimeError('Please set model path!')
+ cfg.load_from = args.load_from
+
+ # load data info
+ data_info = {}
+ load_data_info('data_info', data_info=data_info)
+ cfg.mldb_info = data_info
+ # update check point info
+ reset_ckpt_path(cfg.model, data_info)
+
+ # create show dir
+ os.makedirs(osp.abspath(cfg.show_dir), exist_ok=True)
+
+ # init the logger before other steps
+ cfg.log_file = osp.join(cfg.show_dir, f'{args.timestamp}.log')
+ logger = setup_logger(cfg.log_file)
+
+ # log some basic info
+ logger.info(f'Config:\n{cfg.pretty_text}')
+
+    # init distributed env first, since logger depends on the dist info
+ if args.launcher == 'None':
+ cfg.distributed = False
+ else:
+ cfg.distributed = True
+ init_env(args.launcher, cfg)
+ logger.info(f'Distributed training: {cfg.distributed}')
+
+ # dump config
+ cfg.dump(osp.join(cfg.show_dir, osp.basename(args.config)))
+ test_data_path = args.test_data_path
+ if not os.path.isabs(test_data_path):
+ test_data_path = osp.join(CODE_SPACE, test_data_path)
+
+ if 'json' in test_data_path:
+ test_data = load_from_annos(test_data_path)
+ else:
+        test_data = load_data(test_data_path)
+
+ if not cfg.distributed:
+ main_worker(0, cfg, args.launcher, test_data)
+ else:
+ # distributed training
+ if args.launcher == 'ror':
+ local_rank = cfg.dist_params.local_rank
+ main_worker(local_rank, cfg, args.launcher, test_data)
+ else:
+ mp.spawn(main_worker, nprocs=cfg.dist_params.num_gpus_per_node, args=(cfg, args.launcher, test_data))
+
+def main_worker(local_rank: int, cfg: dict, launcher: str, test_data: list):
+ if cfg.distributed:
+ cfg.dist_params.global_rank = cfg.dist_params.node_rank * cfg.dist_params.num_gpus_per_node + local_rank
+ cfg.dist_params.local_rank = local_rank
+
+ if launcher == 'ror':
+ init_torch_process_group(use_hvd=False)
+ else:
+ torch.cuda.set_device(local_rank)
+ default_timeout = timedelta(minutes=30)
+ dist.init_process_group(
+ backend=cfg.dist_params.backend,
+ init_method=cfg.dist_params.dist_url,
+ world_size=cfg.dist_params.world_size,
+ rank=cfg.dist_params.global_rank,
+ timeout=default_timeout)
+
+ logger = setup_logger(cfg.log_file)
+ # build model
+ model = get_configured_monodepth_model(cfg, )
+
+ # config distributed training
+ if cfg.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
+ device_ids=[local_rank],
+ output_device=local_rank,
+ find_unused_parameters=True)
+ else:
+ model = torch.nn.DataParallel(model).cuda()
+
+ # load ckpt
+ model, _, _, _ = load_ckpt(cfg.load_from, model, strict_match=False)
+ model.eval()
+
+ do_scalecano_test_with_custom_data(
+ model,
+ cfg,
+ test_data,
+ logger,
+ cfg.distributed,
+ local_rank
+ )
+
+if __name__ == '__main__':
+ args = parse_args()
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ args.timestamp = timestamp
+ main(args)
\ No newline at end of file
diff --git a/mono/utils/__init__.py b/mono/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/mono/utils/__init__.py
@@ -0,0 +1 @@
+
diff --git a/mono/utils/avg_meter.py b/mono/utils/avg_meter.py
new file mode 100644
index 0000000000000000000000000000000000000000..37ed9fffa7aa7be7eea094280102168993912f44
--- /dev/null
+++ b/mono/utils/avg_meter.py
@@ -0,0 +1,475 @@
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self) -> None:
+ self.reset()
+
+ def reset(self) -> None:
+ self.val = np.longdouble(0.0)
+ self.avg = np.longdouble(0.0)
+ self.sum = np.longdouble(0.0)
+ self.count = np.longdouble(0.0)
+
+ def update(self, val, n: float = 1) -> None:
+ self.val = val
+ self.sum += val
+ self.count += n
+ self.avg = self.sum / (self.count + 1e-6)
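+
+# Usage sketch: m = AverageMeter(); m.update(0.5); m.update(1.5); m.avg  # ~1.0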
+
+class MetricAverageMeter(AverageMeter):
+ """
+    An AverageMeter designed specifically for evaluating depth and surface-normal results.
+ """
+ def __init__(self, metrics: list) -> None:
+ """ Initialize object. """
+ # average meters for metrics
+ self.abs_rel = AverageMeter()
+ self.rmse = AverageMeter()
+ self.silog = AverageMeter()
+ self.delta1 = AverageMeter()
+ self.delta2 = AverageMeter()
+ self.delta3 = AverageMeter()
+
+ self.metrics = metrics
+
+ self.consistency = AverageMeter()
+ self.log10 = AverageMeter()
+ self.rmse_log = AverageMeter()
+ self.sq_rel = AverageMeter()
+
+ # normal
+ self.normal_mean = AverageMeter()
+ self.normal_rmse = AverageMeter()
+ self.normal_a1 = AverageMeter()
+ self.normal_a2 = AverageMeter()
+
+ self.normal_median = AverageMeter()
+ self.normal_a3 = AverageMeter()
+ self.normal_a4 = AverageMeter()
+ self.normal_a5 = AverageMeter()
+
+
+ def update_metrics_cpu(self,
+ pred: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor,):
+ """
+ Update metrics on cpu
+ """
+
+ assert pred.shape == target.shape
+
+ if len(pred.shape) == 3:
+ pred = pred[:, None, :, :]
+ target = target[:, None, :, :]
+ mask = mask[:, None, :, :]
+ elif len(pred.shape) == 2:
+ pred = pred[None, None, :, :]
+ target = target[None, None, :, :]
+ mask = mask[None, None, :, :]
+
+
+ # Absolute relative error
+ abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask)
+ abs_rel_sum = abs_rel_sum.numpy()
+ valid_pics = valid_pics.numpy()
+ self.abs_rel.update(abs_rel_sum, valid_pics)
+
+ # squared relative error
+ sqrel_sum, _ = get_sqrel_err(pred, target, mask)
+ sqrel_sum = sqrel_sum.numpy()
+ self.sq_rel.update(sqrel_sum, valid_pics)
+
+ # root mean squared error
+ rmse_sum, _ = get_rmse_err(pred, target, mask)
+ rmse_sum = rmse_sum.numpy()
+ self.rmse.update(rmse_sum, valid_pics)
+
+ # log root mean squared error
+ log_rmse_sum, _ = get_rmse_log_err(pred, target, mask)
+ log_rmse_sum = log_rmse_sum.numpy()
+        self.rmse_log.update(log_rmse_sum, valid_pics)
+
+ # log10 error
+ log10_sum, _ = get_log10_err(pred, target, mask)
+ log10_sum = log10_sum.numpy()
+        self.log10.update(log10_sum, valid_pics)
+
+ # scale-invariant root mean squared error in log space
+ silog_sum, _ = get_silog_err(pred, target, mask)
+ silog_sum = silog_sum.numpy()
+ self.silog.update(silog_sum, valid_pics)
+
+ # ratio error, delta1, ....
+        delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_err(pred, target, mask)
+ delta1_sum = delta1_sum.numpy()
+ delta2_sum = delta2_sum.numpy()
+ delta3_sum = delta3_sum.numpy()
+
+ self.delta1.update(delta1_sum, valid_pics)
+        self.delta2.update(delta2_sum, valid_pics)
+        self.delta3.update(delta3_sum, valid_pics)
+
+
+ def update_metrics_gpu(
+ self,
+ pred: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor,
+ is_distributed: bool,
+ pred_next: torch.tensor = None,
+ pose_f1_to_f2: torch.tensor = None,
+ intrinsic: torch.tensor = None):
+ """
+ Update metric on GPU. It supports distributed processing. If multiple machines are employed, please
+ set 'is_distributed' as True.
+ """
+ assert pred.shape == target.shape
+
+ if len(pred.shape) == 3:
+ pred = pred[:, None, :, :]
+ target = target[:, None, :, :]
+ mask = mask[:, None, :, :]
+ elif len(pred.shape) == 2:
+ pred = pred[None, None, :, :]
+ target = target[None, None, :, :]
+ mask = mask[None, None, :, :]
+
+
+ # Absolute relative error
+ abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(abs_rel_sum), dist.all_reduce(valid_pics)
+ abs_rel_sum = abs_rel_sum.cpu().numpy()
+ valid_pics = int(valid_pics)
+ self.abs_rel.update(abs_rel_sum, valid_pics)
+
+ # root mean squared error
+ rmse_sum, _ = get_rmse_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(rmse_sum)
+ rmse_sum = rmse_sum.cpu().numpy()
+ self.rmse.update(rmse_sum, valid_pics)
+
+ # log root mean squared error
+ log_rmse_sum, _ = get_rmse_log_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(log_rmse_sum)
+ log_rmse_sum = log_rmse_sum.cpu().numpy()
+ self.rmse_log.update(log_rmse_sum, valid_pics)
+
+ # log10 error
+ log10_sum, _ = get_log10_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(log10_sum)
+ log10_sum = log10_sum.cpu().numpy()
+ self.log10.update(log10_sum, valid_pics)
+
+ # scale-invariant root mean squared error in log space
+ silog_sum, _ = get_silog_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(silog_sum)
+ silog_sum = silog_sum.cpu().numpy()
+ self.silog.update(silog_sum, valid_pics)
+
+ # ratio error, delta1, ....
+ delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(delta1_sum), dist.all_reduce(delta2_sum), dist.all_reduce(delta3_sum)
+ delta1_sum = delta1_sum.cpu().numpy()
+ delta2_sum = delta2_sum.cpu().numpy()
+ delta3_sum = delta3_sum.cpu().numpy()
+
+ self.delta1.update(delta1_sum, valid_pics)
+ self.delta2.update(delta2_sum, valid_pics)
+ self.delta3.update(delta3_sum, valid_pics)
+
+ # video consistency error
+ # consistency_rel_sum, valid_warps = get_video_consistency_err(pred, pred_next, pose_f1_to_f2, intrinsic)
+ # if is_distributed:
+ # dist.all_reduce(consistency_rel_sum), dist.all_reduce(valid_warps)
+ # consistency_rel_sum = consistency_rel_sum.cpu().numpy()
+ # valid_warps = int(valid_warps)
+ # self.consistency.update(consistency_rel_sum, valid_warps)
+
+ ## for surface normal
+ def update_normal_metrics_gpu(
+ self,
+ pred: torch.Tensor, # (B, 3, H, W)
+ target: torch.Tensor, # (B, 3, H, W)
+ mask: torch.Tensor, # (B, 1, H, W)
+ is_distributed: bool,
+ ):
+ """
+ Update metric on GPU. It supports distributed processing. If multiple machines are employed, please
+ set 'is_distributed' as True.
+ """
+ assert pred.shape == target.shape
+
+ valid_pics = torch.sum(mask, dtype=torch.float32) + 1e-6
+
+ if valid_pics < 10:
+ return
+
+ mean_error = rmse_error = a1_error = a2_error = dist_node_cnt = valid_pics
+ normal_error = torch.cosine_similarity(pred, target, dim=1)
+ normal_error = torch.clamp(normal_error, min=-1.0, max=1.0)
+ angle_error = torch.acos(normal_error) * 180.0 / torch.pi
+ angle_error = angle_error[:, None, :, :]
+ angle_error = angle_error[mask]
+ # Calculation error
+ mean_error = angle_error.sum() / valid_pics
+ rmse_error = torch.sqrt( torch.sum(torch.square(angle_error)) / valid_pics )
+ median_error = angle_error.median()
+ a1_error = 100.0 * (torch.sum(angle_error < 5) / valid_pics)
+ a2_error = 100.0 * (torch.sum(angle_error < 7.5) / valid_pics)
+
+ a3_error = 100.0 * (torch.sum(angle_error < 11.25) / valid_pics)
+ a4_error = 100.0 * (torch.sum(angle_error < 22.5) / valid_pics)
+ a5_error = 100.0 * (torch.sum(angle_error < 30) / valid_pics)
+
+ # if valid_pics > 1e-5:
+ # If the current node gets data with valid normal
+ dist_node_cnt = (valid_pics - 1e-6) / valid_pics
+
+ if is_distributed:
+ dist.all_reduce(dist_node_cnt)
+ dist.all_reduce(mean_error)
+ dist.all_reduce(rmse_error)
+ dist.all_reduce(a1_error)
+ dist.all_reduce(a2_error)
+
+ dist.all_reduce(a3_error)
+ dist.all_reduce(a4_error)
+ dist.all_reduce(a5_error)
+
+ dist_node_cnt = dist_node_cnt.cpu().numpy()
+ self.normal_mean.update(mean_error.cpu().numpy(), dist_node_cnt)
+ self.normal_rmse.update(rmse_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a1.update(a1_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a2.update(a2_error.cpu().numpy(), dist_node_cnt)
+
+ self.normal_median.update(median_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a3.update(a3_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a4.update(a4_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a5.update(a5_error.cpu().numpy(), dist_node_cnt)
+
+
+ def get_metrics(self,):
+        """Collect the running average of every requested metric into a dict."""
+ metrics_dict = {}
+ for metric in self.metrics:
+ metrics_dict[metric] = self.__getattribute__(metric).avg
+ return metrics_dict
+
+
+def get_absrel_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes absolute relative error.
+    Takes preprocessed depths (no nans, infs and non-positive values).
+    pred, target, and mask should be in the shape of [b, c, h, w]
+    """
+
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ # Mean Absolute Relative Error
+ rel = torch.abs(t_m - p_m) / (t_m + 1e-10) # compute errors
+ abs_rel_sum = torch.sum(rel.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ abs_err = abs_rel_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(abs_err), valid_pics
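+
+# AbsRel per image = mean over valid pixels of |pred - gt| / gt; the sum over the batch is
+# returned together with the number of images that contain at least one valid pixel.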
+
+def get_sqrel_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes squared relative error.
+    Takes preprocessed depths (no nans, infs and non-positive values).
+    pred, target, and mask should be in the shape of [b, c, h, w]
+    """
+
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ # squared Relative Error
+ sq_rel = torch.abs(t_m - p_m) ** 2 / (t_m + 1e-10) # compute errors
+ sq_rel_sum = torch.sum(sq_rel.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ sqrel_err = sq_rel_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(sqrel_err), valid_pics
+
+def get_log10_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes log10 error.
+    Takes preprocessed depths (no nans, infs and non-positive values).
+    pred, target, and mask should be in the shape of [b, c, h, w]
+    """
+
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask
+ log10_diff = torch.abs(diff_log)
+ log10_sum = torch.sum(log10_diff.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ log10_err = log10_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(log10_err), valid_pics
+
+def get_rmse_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes rmse error.
+    Takes preprocessed depths (no nans, infs and non-positive values).
+    pred, target, and mask should be in the shape of [b, c, h, w]
+    """
+
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ square = (t_m - p_m) ** 2
+ rmse_sum = torch.sum(square.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ rmse = torch.sqrt(rmse_sum / (num + 1e-10))
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(rmse), valid_pics
+
+def get_rmse_log_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes log rmse error.
+    Takes preprocessed depths (no nans, infs and non-positive values).
+    pred, target, and mask should be in the shape of [b, c, h, w]
+    """
+
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask
+ square = diff_log ** 2
+ rmse_log_sum = torch.sum(square.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ rmse_log = torch.sqrt(rmse_log_sum / (num + 1e-10))
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(rmse_log), valid_pics
+
+def get_silog_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+    Computes the scale-invariant log RMSE (silog) error.
+    Takes preprocessed depths (no nans, infs and non-positive values).
+    pred, target, and mask should be in the shape of [b, c, h, w]
+    """
+
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask
+ diff_log_sum = torch.sum(diff_log.reshape((b, c, -1)), dim=2) # [b, c]
+ diff_log_square = diff_log ** 2
+ diff_log_square_sum = torch.sum(diff_log_square.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ silog = torch.sqrt(diff_log_square_sum / (num + 1e-10) - (diff_log_sum / (num + 1e-10)) ** 2)
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(silog), valid_pics
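+
+# silog per image = sqrt( mean(d^2) - mean(d)^2 ) with d = log10(pred) - log10(gt)
+# computed over the valid pixels of each image.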
+
+def get_ratio_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor,
+ ):
+ """
+ Computes the percentage of pixels for which the ratio of the two depth maps is less than a given threshold.
+    Takes preprocessed depths (no nans, infs and non-positive values).
+    pred, target, and mask should be in the shape of [b, c, h, w]
+    """
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred
+
+ gt_pred = t_m / (p_m + 1e-10)
+ pred_gt = p_m / (t_m + 1e-10)
+ gt_pred = gt_pred.reshape((b, c, -1))
+ pred_gt = pred_gt.reshape((b, c, -1))
+ gt_pred_gt = torch.cat((gt_pred, pred_gt), axis=1)
+ ratio_max = torch.amax(gt_pred_gt, axis=1)
+
+ delta_1_sum = torch.sum((ratio_max < 1.25), dim=1) # [b, ]
+ delta_2_sum = torch.sum((ratio_max < 1.25 ** 2), dim=1) # [b, ]
+ delta_3_sum = torch.sum((ratio_max < 1.25 ** 3), dim=1) # [b, ]
+ num = torch.sum(mask.reshape((b, -1)), dim=1) # [b, ]
+
+ delta_1 = delta_1_sum / (num + 1e-10)
+ delta_2 = delta_2_sum / (num + 1e-10)
+ delta_3 = delta_3_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
+
+ return torch.sum(delta_1), torch.sum(delta_2), torch.sum(delta_3), valid_pics
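+
+# delta_k is the fraction of valid pixels with max(gt/pred, pred/gt) < 1.25**k, for k = 1, 2, 3.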
+
+
+if __name__ == '__main__':
+ cfg = ['abs_rel', 'delta1']
+ dam = MetricAverageMeter(cfg)
+
+ pred_depth = np.random.random([2, 480, 640])
+ gt_depth = np.random.random([2, 480, 640]) - 0.5
+ intrinsic = [[100, 100, 200, 200], [200, 200, 300, 300]]
+
+ pred = torch.from_numpy(pred_depth).cuda()
+ gt = torch.from_numpy(gt_depth).cuda()
+
+ mask = gt > 0
+ dam.update_metrics_gpu(pred, gt, mask, False)
+ eval_error = dam.get_metrics()
+ print(eval_error)
+
\ No newline at end of file
diff --git a/mono/utils/comm.py b/mono/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..939e4e175c14563d5d13e77e6b56fd1a34668ebf
--- /dev/null
+++ b/mono/utils/comm.py
@@ -0,0 +1,322 @@
+import importlib
+import torch
+import torch.distributed as dist
+from .avg_meter import AverageMeter
+from collections import defaultdict, OrderedDict
+import os
+import socket
+from mmcv.utils import collect_env as collect_base_env
+try:
+ from mmcv.utils import get_git_hash
+except:
+ from mmengine.utils import get_git_hash
+#import mono.mmseg as mmseg
+# import mmseg
+import time
+import datetime
+import logging
+
+
+def main_process() -> bool:
+ return get_rank() == 0
+ #return not cfg.distributed or \
+ # (cfg.distributed and cfg.local_rank == 0)
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+def _find_free_port():
+ # refer to https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # Binding to port 0 will cause the OS to find an available port for us
+ sock.bind(('', 0))
+ port = sock.getsockname()[1]
+ sock.close()
+ # NOTE: there is still a chance the port could be taken by other processes.
+ return port
+
+def _is_free_port(port):
+ ips = socket.gethostbyname_ex(socket.gethostname())[-1]
+ ips.append('localhost')
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ return all(s.connect_ex((ip, port)) != 0 for ip in ips)
+
+
+# def collect_env():
+# """Collect the information of the running environments."""
+# env_info = collect_base_env()
+# env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
+
+# return env_info
+
+def init_env(launcher, cfg):
+ """Initialize distributed training environment.
+ If argument ``cfg.dist_params.dist_url`` is specified as 'env://', then the master port will be system
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+ environment variable, then a default port ``29500`` will be used.
+ """
+ if launcher == 'slurm':
+ _init_dist_slurm(cfg)
+ elif launcher == 'ror':
+ _init_dist_ror(cfg)
+ elif launcher == 'None':
+ _init_none_dist(cfg)
+ else:
+        raise RuntimeError(f'{launcher} has not been supported!')
+
+def _init_none_dist(cfg):
+ cfg.dist_params.num_gpus_per_node = 1
+ cfg.dist_params.world_size = 1
+ cfg.dist_params.nnodes = 1
+ cfg.dist_params.node_rank = 0
+ cfg.dist_params.global_rank = 0
+ cfg.dist_params.local_rank = 0
+ os.environ["WORLD_SIZE"] = str(1)
+
+def _init_dist_ror(cfg):
+ from ac2.ror.comm import get_local_rank, get_world_rank, get_local_size, get_node_rank, get_world_size
+ cfg.dist_params.num_gpus_per_node = get_local_size()
+ cfg.dist_params.world_size = get_world_size()
+ cfg.dist_params.nnodes = (get_world_size()) // (get_local_size())
+ cfg.dist_params.node_rank = get_node_rank()
+ cfg.dist_params.global_rank = get_world_rank()
+ cfg.dist_params.local_rank = get_local_rank()
+ os.environ["WORLD_SIZE"] = str(get_world_size())
+
+
+def _init_dist_slurm(cfg):
+ if 'NNODES' not in os.environ:
+ os.environ['NNODES'] = str(cfg.dist_params.nnodes)
+ if 'NODE_RANK' not in os.environ:
+ os.environ['NODE_RANK'] = str(cfg.dist_params.node_rank)
+
+ #cfg.dist_params.
+ num_gpus = torch.cuda.device_count()
+ world_size = int(os.environ['NNODES']) * num_gpus
+ os.environ['WORLD_SIZE'] = str(world_size)
+
+ # config port
+ if 'MASTER_PORT' in os.environ:
+ master_port = str(os.environ['MASTER_PORT']) # use MASTER_PORT in the environment variable
+ else:
+ # if torch.distributed default port(29500) is available
+ # then use it, else find a free port
+ if _is_free_port(16500):
+ master_port = '16500'
+ else:
+ master_port = str(_find_free_port())
+ os.environ['MASTER_PORT'] = master_port
+
+ # config addr
+ if 'MASTER_ADDR' in os.environ:
+        master_addr = str(os.environ['MASTER_ADDR']) # use MASTER_ADDR in the environment variable
+ # elif cfg.dist_params.dist_url is not None:
+ # master_addr = ':'.join(str(cfg.dist_params.dist_url).split(':')[:2])
+ else:
+ master_addr = '127.0.0.1' #'tcp://127.0.0.1'
+ os.environ['MASTER_ADDR'] = master_addr
+
+ # set dist_url to 'env://'
+ cfg.dist_params.dist_url = 'env://' #f"{master_addr}:{master_port}"
+
+ cfg.dist_params.num_gpus_per_node = num_gpus
+ cfg.dist_params.world_size = world_size
+ cfg.dist_params.nnodes = int(os.environ['NNODES'])
+ cfg.dist_params.node_rank = int(os.environ['NODE_RANK'])
+
+ # if int(os.environ['NNODES']) > 1 and cfg.dist_params.dist_url.startswith("file://"):
+ # raise Warning("file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://")
+
+
+def get_func(func_name):
+ """
+ Helper to return a function object by name. func_name must identify
+ a function in this module or the path to a function relative to the base
+ module.
+ @ func_name: function name.
+ """
+ if func_name == '':
+ return None
+ try:
+ parts = func_name.split('.')
+ # Refers to a function in this module
+ if len(parts) == 1:
+ return globals()[parts[0]]
+ # Otherwise, assume we're referencing a module under modeling
+ module_name = '.'.join(parts[:-1])
+ module = importlib.import_module(module_name)
+ return getattr(module, parts[-1])
+ except:
+ raise RuntimeError(f'Failed to find function: {func_name}')
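+
+# e.g. get_func('mono.model.model_pipelines.DensePredModel') resolves a class from its
+# dotted path; BaseDepthModel and DensePredModel use it to build their submodules from cfg.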
+
+class Timer(object):
+ """A simple timer."""
+
+ def __init__(self):
+ self.reset()
+
+ def tic(self):
+        # using time.time instead of time.clock because time.clock
+        # does not normalize for multithreading
+ self.start_time = time.time()
+
+ def toc(self, average=True):
+ self.diff = time.time() - self.start_time
+ self.total_time += self.diff
+ self.calls += 1
+ self.average_time = self.total_time / self.calls
+ if average:
+ return self.average_time
+ else:
+ return self.diff
+
+ def reset(self):
+ self.total_time = 0.
+ self.calls = 0
+ self.start_time = 0.
+ self.diff = 0.
+ self.average_time = 0.
+
+class TrainingStats(object):
+ """Track vital training statistics."""
+ def __init__(self, log_period, tensorboard_logger=None):
+ self.log_period = log_period
+ self.tblogger = tensorboard_logger
+ self.tb_ignored_keys = ['iter', 'eta', 'epoch', 'time']
+ self.iter_timer = Timer()
+ # Window size for smoothing tracked values (with median filtering)
+ self.filter_size = log_period
+ def create_smoothed_value():
+ return AverageMeter()
+ self.smoothed_losses = defaultdict(create_smoothed_value)
+ #self.smoothed_metrics = defaultdict(create_smoothed_value)
+ #self.smoothed_total_loss = AverageMeter()
+
+
+ def IterTic(self):
+ self.iter_timer.tic()
+
+ def IterToc(self):
+ return self.iter_timer.toc(average=False)
+
+ def reset_iter_time(self):
+ self.iter_timer.reset()
+
+ def update_iter_stats(self, losses_dict):
+ """Update tracked iteration statistics."""
+ for k, v in losses_dict.items():
+ self.smoothed_losses[k].update(float(v), 1)
+
+ def log_iter_stats(self, cur_iter, optimizer, max_iters, val_err={}):
+ """Log the tracked statistics."""
+ if (cur_iter % self.log_period == 0):
+ stats = self.get_stats(cur_iter, optimizer, max_iters, val_err)
+ log_stats(stats)
+ if self.tblogger:
+ self.tb_log_stats(stats, cur_iter)
+ for k, v in self.smoothed_losses.items():
+ v.reset()
+
+ def tb_log_stats(self, stats, cur_iter):
+ """Log the tracked statistics to tensorboard"""
+ for k in stats:
+ # ignore some logs
+ if k not in self.tb_ignored_keys:
+ v = stats[k]
+ if isinstance(v, dict):
+ self.tb_log_stats(v, cur_iter)
+ else:
+ self.tblogger.add_scalar(k, v, cur_iter)
+
+
+ def get_stats(self, cur_iter, optimizer, max_iters, val_err = {}):
+ eta_seconds = self.iter_timer.average_time * (max_iters - cur_iter)
+
+ eta = str(datetime.timedelta(seconds=int(eta_seconds)))
+ stats = OrderedDict(
+ iter=cur_iter, # 1-indexed
+ time=self.iter_timer.average_time,
+ eta=eta,
+ )
+ optimizer_state_dict = optimizer.state_dict()
+ lr = {}
+ for i in range(len(optimizer_state_dict['param_groups'])):
+ lr_name = 'group%d_lr' % i
+ lr[lr_name] = optimizer_state_dict['param_groups'][i]['lr']
+
+ stats['lr'] = OrderedDict(lr)
+ for k, v in self.smoothed_losses.items():
+ stats[k] = v.avg
+
+ stats['val_err'] = OrderedDict(val_err)
+ stats['max_iters'] = max_iters
+ return stats
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+ Args:
+ @input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ @average (bool): whether to do average or sum
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
+
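+# Minimal usage sketch for reduce_dict (variable names are illustrative; assumes
+# torch.distributed is initialized and the values are scalar CUDA tensors, as the docstring requires):
+#   losses = {'total_loss': total_loss.detach(), 'loss_depth': loss_depth.detach()}
+#   reduced = reduce_dict(losses)  # on rank 0 the values are summed across ranks (and averaged when average=True)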
+
+def log_stats(stats):
+ """Log training statistics to the terminal."""
+ logger = logging.getLogger()
+ lines = "[Step %d/%d]\n" % (
+ stats['iter'], stats['max_iters'])
+
+ lines += "\t\tloss: %.3f, time: %.6f, eta: %s\n" % (
+ stats['total_loss'], stats['time'], stats['eta'])
+
+ # log the individual losses
+ loss_strs = ["%s: %.3f" % (k, v) for k, v in stats.items()
+ if 'loss' in k.lower() and 'total_loss' not in k.lower()]
+ lines += "\t\t" + ", ".join(loss_strs)
+ lines += '\n'
+
+ # validation criteria
+ lines += "\t\tlast val err: " + ", ".join("%s: %.6f" % (k, v) for k, v in stats['val_err'].items())
+ lines += '\n'
+
+ # lr in different groups
+ lines += "\t\t" + ", ".join("%s: %.8f" % (k, v) for k, v in stats['lr'].items())
+ lines += '\n'
+ logger.info(lines[:-1]) # strip the trailing newline
+
diff --git a/mono/utils/custom_data.py b/mono/utils/custom_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9fab47478bc471c51b5454cc15550079ebec21b
--- /dev/null
+++ b/mono/utils/custom_data.py
@@ -0,0 +1,34 @@
+import glob
+import os
+import json
+import cv2
+
+def load_from_annos(anno_path):
+ with open(anno_path, 'r') as f:
+ annos = json.load(f)['files']
+
+ datas = []
+ for i, anno in enumerate(annos):
+ rgb = anno['rgb']
+ depth = anno['depth'] if 'depth' in anno else None
+ depth_scale = anno['depth_scale'] if 'depth_scale' in anno else 1.0
+ intrinsic = anno['cam_in'] if 'cam_in' in anno else None
+ normal = anno['normal'] if 'normal' in anno else None
+
+ data_i = {
+ 'rgb': rgb,
+ 'depth': depth,
+ 'depth_scale': depth_scale,
+ 'intrinsic': intrinsic,
+ 'filename': os.path.basename(rgb),
+ 'folder': rgb.split('/')[-3],
+ 'normal': normal
+ }
+ datas.append(data_i)
+ return datas
+
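+# Note: load_from_annos expects a json of the form {"files": [...]}, where each entry
+# provides at least "rgb" and may also provide "depth", "depth_scale", "cam_in"
+# ([fx, fy, u0, v0]) and "normal"; see e.g. training/data_server_info/annos_test_normal_nyu_example.json
+# for a concrete sample.
+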
+def load_data(path: str):
+ rgbs = glob.glob(path + '/*.jpg') + glob.glob(path + '/*.png')
+ #intrinsic = [835.8179931640625, 835.8179931640625, 961.5419921875, 566.8090209960938] #[721.53769, 721.53769, 609.5593, 172.854]
+ data = [{'rgb': i, 'depth': None, 'intrinsic': None, 'filename': os.path.basename(i), 'folder': i.split('/')[-3]} for i in rgbs]
+ return data
\ No newline at end of file
diff --git a/mono/utils/do_test.py b/mono/utils/do_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..89ee4afc9d6cd67ec491af6726c850347cafc099
--- /dev/null
+++ b/mono/utils/do_test.py
@@ -0,0 +1,364 @@
+import torch
+import torch.nn.functional as F
+import logging
+import os
+import os.path as osp
+from mono.utils.avg_meter import MetricAverageMeter
+from mono.utils.visualization import save_val_imgs, create_html, save_raw_imgs, save_normal_val_imgs
+import cv2
+from tqdm import tqdm
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+
+from mono.utils.unproj_pcd import reconstruct_pcd, save_point_cloud
+
+def to_cuda(data: dict):
+ for k, v in data.items():
+ if isinstance(v, torch.Tensor):
+ data[k] = v.cuda(non_blocking=True)
+ if isinstance(v, list) and len(v)>=1 and isinstance(v[0], torch.Tensor):
+ for i, l_i in enumerate(v):
+ data[k][i] = l_i.cuda(non_blocking=True)
+ return data
+
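+# The two helpers below align a scale-ambiguous prediction to a ground-truth depth map:
+# align_scale matches the medians over valid pixels (e.g. if the prediction is uniformly
+# half the metric depth, the recovered scale is ~2), while align_scale_shift fits a
+# least-squares scale + shift via np.polyfit and falls back to median scaling when the
+# fitted scale is negative.
+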
+def align_scale(pred: torch.tensor, target: torch.tensor):
+ mask = target > 0
+ if torch.sum(mask) > 10:
+ scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
+ else:
+ scale = 1
+ pred_scaled = pred * scale
+ return pred_scaled, scale
+
+def align_scale_shift(pred: torch.tensor, target: torch.tensor):
+ mask = target > 0
+ target_mask = target[mask].cpu().numpy()
+ pred_mask = pred[mask].cpu().numpy()
+ if torch.sum(mask) > 10:
+ scale, shift = np.polyfit(pred_mask, target_mask, deg=1)
+ if scale < 0:
+ scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
+ shift = 0
+ else:
+ scale = 1
+ shift = 0
+ pred = pred * scale + shift
+ return pred, scale
+
+def align_scale_shift_numpy(pred: np.array, target: np.array):
+ mask = target > 0
+ target_mask = target[mask]
+ pred_mask = pred[mask]
+ if np.sum(mask) > 10:
+ scale, shift = np.polyfit(pred_mask, target_mask, deg=1)
+ if scale < 0:
+ scale = np.median(target[mask]) / (np.median(pred[mask]) + 1e-8)
+ shift = 0
+ else:
+ scale = 1
+ shift = 0
+ pred = pred * scale + shift
+ return pred, scale
+
+
+def build_camera_model(H : int, W : int, intrinsics : list) -> np.array:
+ """
+ Encode the camera intrinsic parameters (focal length and principal point) to a 4-channel map.
+ """
+ fx, fy, u0, v0 = intrinsics
+ f = (fx + fy) / 2.0
+ # principal point location
+ x_row = np.arange(0, W).astype(np.float32)
+ x_row_center_norm = (x_row - u0) / W
+ x_center = np.tile(x_row_center_norm, (H, 1)) # [H, W]
+
+ y_col = np.arange(0, H).astype(np.float32)
+ y_col_center_norm = (y_col - v0) / H
+ y_center = np.tile(y_col_center_norm, (W, 1)).T # [H, W]
+
+ # FoV
+ fov_x = np.arctan(x_center / (f / W))
+ fov_y = np.arctan(y_center / (f / H))
+
+ cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2)
+ return cam_model
+
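+# Illustrative shapes (example values assumed): for a 480x640 image with intrinsics
+# [fx, fy, u0, v0] = [500.0, 500.0, 320.0, 240.0], build_camera_model returns an array of
+# shape (480, 640, 4) whose channels are [x_center, y_center, fov_x, fov_y];
+# transform_test_data_scalecano later transposes it to (1, 4, H, W) and builds a pyramid
+# of 5 downscaled copies for the network.
+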
+def resize_for_input(image, output_shape, intrinsic, canonical_shape, to_canonical_ratio):
+ """
+ Resize the input.
+ Resizing consists of two steps: 1) map to the canonical space (adjust the camera model); 2) resize the image while the camera model is held fixed. Thus the
+ label is scaled by the resize factor.
+ """
+ padding = [123.675, 116.28, 103.53]
+ h, w, _ = image.shape
+ resize_ratio_h = output_shape[0] / canonical_shape[0]
+ resize_ratio_w = output_shape[1] / canonical_shape[1]
+ to_scale_ratio = min(resize_ratio_h, resize_ratio_w)
+
+ resize_ratio = to_canonical_ratio * to_scale_ratio
+
+ reshape_h = int(resize_ratio * h)
+ reshape_w = int(resize_ratio * w)
+
+ pad_h = max(output_shape[0] - reshape_h, 0)
+ pad_w = max(output_shape[1] - reshape_w, 0)
+ pad_h_half = int(pad_h / 2)
+ pad_w_half = int(pad_w / 2)
+
+ # resize
+ image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ # padding
+ image = cv2.copyMakeBorder(
+ image,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=padding)
+
+ # Resize, adjust principal point
+ intrinsic[2] = intrinsic[2] * to_scale_ratio
+ intrinsic[3] = intrinsic[3] * to_scale_ratio
+
+ cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
+ cam_model = cv2.copyMakeBorder(
+ cam_model,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=-1)
+
+ pad=[pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]
+ label_scale_factor=1/to_scale_ratio
+ return image, cam_model, pad, label_scale_factor
+
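+# The returned pad ([top, bottom, left, right]) is consumed by get_prediction to crop the
+# network output back to the unpadded region, and label_scale_factor (= 1 / to_scale_ratio)
+# enters the final depth rescaling together with the canonical-camera ratio computed in
+# transform_test_data_scalecano.
+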
+
+def get_prediction(
+ model: torch.nn.Module,
+ input: torch.tensor,
+ cam_model: torch.tensor,
+ pad_info: torch.tensor,
+ scale_info: torch.tensor,
+ gt_depth: torch.tensor,
+ normalize_scale: float,
+ ori_shape: list=[],
+):
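+ """
+ Run model.module.inference on one pre-processed image and map the canonical-space output
+ back to metric depth: crop the padded border (pad_info), resize to the original/gt shape,
+ multiply by normalize_scale (cfg.data_basic.depth_range[1] in the test loop) and divide by scale_info.
+ If gt_depth is given, a median-aligned copy and the alignment scale are also returned.
+ """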
+
+ data = dict(
+ input=input,
+ cam_model=cam_model,
+ )
+ pred_depth, confidence, output_dict = model.module.inference(data)
+ pred_depth = pred_depth.squeeze()
+ pred_depth = pred_depth[pad_info[0] : pred_depth.shape[0] - pad_info[1], pad_info[2] : pred_depth.shape[1] - pad_info[3]]
+ if gt_depth is not None:
+ resize_shape = gt_depth.shape
+ elif ori_shape != []:
+ resize_shape = ori_shape
+ else:
+ resize_shape = pred_depth.shape
+
+ pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], resize_shape, mode='bilinear').squeeze() # to original size
+ pred_depth = pred_depth * normalize_scale / scale_info
+ if gt_depth is not None:
+ pred_depth_scale, scale = align_scale(pred_depth, gt_depth)
+ else:
+ pred_depth_scale = None
+ scale = None
+
+ return pred_depth, pred_depth_scale, scale, output_dict
+
+def transform_test_data_scalecano(rgb, intrinsic, data_basic):
+ """
+ Pre-process the input for forwarding. Employ the 'label scale canonical transformation'.
+ Args:
+ rgb: input rgb image. [H, W, 3]
+ intrinsic: camera intrinsic parameter, [fx, fy, u0, v0]
+ data_basic: predefined canonical space in configs.
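+ Returns:
+ rgb: normalized image tensor of shape [1, 3, H, W] on the GPU.
+ cam_model_stacks: list of 5 camera-model maps at 1/2, 1/4, 1/8, 1/16, 1/32 of the forward size.
+ pad: [pad_top, pad_bottom, pad_left, pad_right] applied during resizing.
+ label_scale_factor: combined canonical-space and resize scale factor for the depth labels.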
+ """
+ canonical_space = data_basic['canonical_space']
+ forward_size = data_basic.crop_size
+ mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None]
+ std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None]
+
+ # BGR to RGB
+ rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB)
+
+ ori_h, ori_w, _ = rgb.shape
+ ori_focal = (intrinsic[0] + intrinsic[1]) / 2
+ canonical_focal = canonical_space['focal_length']
+
+ cano_label_scale_ratio = canonical_focal / ori_focal
+
+ canonical_intrinsic = [
+ intrinsic[0] * cano_label_scale_ratio,
+ intrinsic[1] * cano_label_scale_ratio,
+ intrinsic[2],
+ intrinsic[3],
+ ]
+
+ # resize
+ rgb, cam_model, pad, resize_label_scale_ratio = resize_for_input(rgb, forward_size, canonical_intrinsic, [ori_h, ori_w], 1.0)
+
+ # label scale factor
+ label_scale_factor = cano_label_scale_ratio * resize_label_scale_ratio
+
+ rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float()
+ rgb = torch.div((rgb - mean), std)
+ rgb = rgb[None, :, :, :].cuda()
+
+ cam_model = torch.from_numpy(cam_model.transpose((2, 0, 1))).float()
+ cam_model = cam_model[None, :, :, :].cuda()
+ cam_model_stacks = [
+ torch.nn.functional.interpolate(cam_model, size=(cam_model.shape[2]//i, cam_model.shape[3]//i), mode='bilinear', align_corners=False)
+ for i in [2, 4, 8, 16, 32]
+ ]
+ return rgb, cam_model_stacks, pad, label_scale_factor
+
+def do_scalecano_test_with_custom_data(
+ model: torch.nn.Module,
+ cfg: dict,
+ test_data: list,
+ logger: logging.RootLogger,
+ is_distributed: bool = True,
+ local_rank: int = 0,
+):
+
+ show_dir = cfg.show_dir
+ save_interval = 1
+ save_imgs_dir = show_dir + '/vis'
+ os.makedirs(save_imgs_dir, exist_ok=True)
+ save_pcd_dir = show_dir + '/pcd'
+ os.makedirs(save_pcd_dir, exist_ok=True)
+
+ normalize_scale = cfg.data_basic.depth_range[1]
+ dam = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3'])
+ dam_median = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3'])
+ dam_global = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3'])
+
+ for i, an in tqdm(enumerate(test_data)):
+ #for i, an in enumerate(test_data):
+ print(an['rgb'])
+ rgb_origin = cv2.imread(an['rgb'])[:, :, ::-1].copy()
+ if an['depth'] is not None:
+ gt_depth = cv2.imread(an['depth'], -1)
+ gt_depth_scale = an['depth_scale']
+ gt_depth = gt_depth / gt_depth_scale
+ gt_depth_flag = True
+ else:
+ gt_depth = None
+ gt_depth_flag = False
+ intrinsic = an['intrinsic']
+ if intrinsic is None:
+ intrinsic = [1000.0, 1000.0, rgb_origin.shape[1]/2, rgb_origin.shape[0]/2]
+ # intrinsic = [542.0, 542.0, 963.706, 760.199]
+ print(intrinsic)
+ rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(rgb_origin, intrinsic, cfg.data_basic)
+
+ pred_depth, pred_depth_scale, scale, output = get_prediction(
+ model = model,
+ input = rgb_input,
+ cam_model = cam_models_stacks,
+ pad_info = pad,
+ scale_info = label_scale_factor,
+ gt_depth = None,
+ normalize_scale = normalize_scale,
+ ori_shape=[rgb_origin.shape[0], rgb_origin.shape[1]],
+ )
+
+ pred_depth = (pred_depth > 0) * (pred_depth < 300) * pred_depth
+ if gt_depth_flag:
+
+ pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], (gt_depth.shape[0], gt_depth.shape[1]), mode='bilinear').squeeze() # to original size
+
+ gt_depth = torch.from_numpy(gt_depth).cuda()
+
+ pred_depth_median = pred_depth * gt_depth[gt_depth != 0].median() / pred_depth[gt_depth != 0].median()
+ pred_global, _ = align_scale_shift(pred_depth, gt_depth)
+
+ mask = (gt_depth > 1e-8)
+ dam.update_metrics_gpu(pred_depth, gt_depth, mask, is_distributed)
+ dam_median.update_metrics_gpu(pred_depth_median, gt_depth, mask, is_distributed)
+ dam_global.update_metrics_gpu(pred_global, gt_depth, mask, is_distributed)
+ print(gt_depth[gt_depth != 0].median() / pred_depth[gt_depth != 0].median(), )
+
+ if i % save_interval == 0:
+ os.makedirs(osp.join(save_imgs_dir, an['folder']), exist_ok=True)
+ rgb_torch = torch.from_numpy(rgb_origin).to(pred_depth.device).permute(2, 0, 1)
+ mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None].to(rgb_torch.device)
+ std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None].to(rgb_torch.device)
+ rgb_torch = torch.div((rgb_torch - mean), std)
+
+ save_val_imgs(
+ i,
+ pred_depth,
+ gt_depth if gt_depth is not None else torch.ones_like(pred_depth, device=pred_depth.device),
+ rgb_torch,
+ osp.join(an['folder'], an['filename']),
+ save_imgs_dir,
+ )
+ #save_raw_imgs(pred_depth.detach().cpu().numpy(), rgb_torch, osp.join(an['folder'], an['filename']), save_imgs_dir, 1000.0)
+
+ # pcd
+ pred_depth = pred_depth.detach().cpu().numpy()
+ #pcd = reconstruct_pcd(pred_depth, intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3])
+ #os.makedirs(osp.join(save_pcd_dir, an['folder']), exist_ok=True)
+ #save_point_cloud(pcd.reshape((-1, 3)), rgb_origin.reshape(-1, 3), osp.join(save_pcd_dir, an['folder'], an['filename'][:-4]+'.ply'))
+
+ if an['intrinsic'] is None:
+ #for r in [0.9, 1.0, 1.1]:
+ for r in [1.0]:
+ #for f in [600, 800, 1000, 1250, 1500]:
+ for f in [1000]:
+ pcd = reconstruct_pcd(pred_depth, f * r, f * (2-r), intrinsic[2], intrinsic[3])
+ fstr = '_fx_' + str(int(f * r)) + '_fy_' + str(int(f * (2-r)))
+ os.makedirs(osp.join(save_pcd_dir, an['folder']), exist_ok=True)
+ save_point_cloud(pcd.reshape((-1, 3)), rgb_origin.reshape(-1, 3), osp.join(save_pcd_dir, an['folder'], an['filename'][:-4] + fstr +'.ply'))
+
+ if "normal_out_list" in output.keys():
+
+ normal_out_list = output['normal_out_list']
+ pred_normal = normal_out_list[0][:, :3, :, :] # (B, 3, H, W)
+ H, W = pred_normal.shape[2:]
+ pred_normal = pred_normal[:, :, pad[0]:H-pad[1], pad[2]:W-pad[3]]
+
+ gt_normal = None
+ #if gt_normal_flag:
+ if False:
+ pred_normal = torch.nn.functional.interpolate(pred_normal, size=gt_normal.shape[2:], mode='bilinear', align_corners=True)
+ gt_normal = cv2.imread(norm_path)
+ gt_normal = cv2.cvtColor(gt_normal, cv2.COLOR_BGR2RGB)
+ gt_normal = np.array(gt_normal).astype(np.uint8)
+ gt_normal = ((gt_normal.astype(np.float32) / 255.0) * 2.0) - 1.0
+ norm_valid_mask = (np.linalg.norm(gt_normal, axis=2, keepdims=True) > 0.5)
+ gt_normal = gt_normal * norm_valid_mask
+ gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True)
+ dam.update_normal_metrics_gpu(pred_normal, gt_normal, gt_normal_mask, cfg.distributed)# save valiad normal
+
+ if i % save_interval == 0:
+ save_normal_val_imgs(i,
+ pred_normal,
+ gt_normal if gt_normal is not None else torch.ones_like(pred_normal, device=pred_normal.device),
+ rgb_torch, # data['input'],
+ osp.join(an['folder'], 'normal_'+an['filename']),
+ save_imgs_dir,
+ )
+
+
+ #if gt_depth_flag:
+ if False:
+ eval_error = dam.get_metrics()
+ print('w/o match :', eval_error)
+
+ eval_error_median = dam_median.get_metrics()
+ print('median match :', eval_error_median)
+
+ eval_error_global = dam_global.get_metrics()
+ print('global match :', eval_error_global)
+ else:
+ print('missing gt_depth, only save visualizations...')
diff --git a/mono/utils/logger.py b/mono/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca48c613b2fdc5352b13ccb7d0bfdc1df5e3b531
--- /dev/null
+++ b/mono/utils/logger.py
@@ -0,0 +1,102 @@
+import atexit
+import logging
+import os
+import sys
+import time
+import torch
+from termcolor import colored
+
+__all__ = ["setup_logger", ]
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
+ if len(self._abbrev_name):
+ self._abbrev_name = self._abbrev_name + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+def setup_logger(
+ output=None, distributed_rank=0, *, name='metricdepth', color=True, abbrev_name=None
+):
+ """
+ Initialize the logger (adapted from detectron2) and set its verbosity level to "INFO".
+ Args:
+ output (str): a file name or a directory to save log. If None, will not save log file.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+ abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
+ Set to "" to not log the root module in logs.
+ By default, will abbreviate "detectron2" to "d2" and leave other
+ modules unchanged.
+ Returns:
+ logging.Logger: a logger
+ """
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG
+ logger.propagate = False
+
+ if abbrev_name is None:
+ abbrev_name = "d2"
+
+ plain_formatter = logging.Formatter(
+ "[%(asctime)s] %(name)s %(levelname)s %(message)s ", datefmt="%m/%d %H:%M:%S"
+ )
+ # stdout logging: master only
+ if distributed_rank == 0:
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S",
+ root_name=name,
+ abbrev_name=str(abbrev_name),
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # file logging: all workers
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "log.txt")
+ if distributed_rank > 0:
+ filename = filename + ".rank{}".format(distributed_rank)
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ fh = logging.StreamHandler(_cached_log_stream(filename))
+ fh.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG
+ fh.setFormatter(plain_formatter)
+ logger.addHandler(fh)
+
+
+ return logger
+
+from iopath.common.file_io import PathManager as PathManagerBase
+
+
+PathManager = PathManagerBase()
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+def _cached_log_stream(filename):
+ # use a 1K buffer if writing to cloud storage
+ io = PathManager.open(filename, "a", buffering=1024 if "://" in filename else -1)
+ atexit.register(io.close)
+ return io
+
\ No newline at end of file
diff --git a/mono/utils/mldb.py b/mono/utils/mldb.py
new file mode 100644
index 0000000000000000000000000000000000000000..d74ac53fd0302e2e954105bade52e6de4c18e2f6
--- /dev/null
+++ b/mono/utils/mldb.py
@@ -0,0 +1,34 @@
+from types import ModuleType
+import data_info
+
+def load_data_info(module_name, data_info={}, mldb_type='mldb_info', module=None):
+ if module is None:
+ module = globals().get(module_name, None)
+ if module:
+ for key, value in module.__dict__.items():
+ if not (key.startswith('__')) and not (key.startswith('_')):
+ if key == 'mldb_info':
+ data_info.update(value)
+ elif isinstance(value, ModuleType):
+ load_data_info(module_name + '.' + key, data_info, module=value)
+ else:
+ raise RuntimeError(f'Tried to access "mldb_info", but cannot find the {module_name} module.')
+
+def reset_ckpt_path(cfg, data_info):
+ if isinstance(cfg, dict):
+ for key in cfg.keys():
+ if key == 'backbone':
+ new_ckpt_path = data_info['checkpoint']['mldb_root'] + '/' + data_info['checkpoint'][cfg.backbone.type]
+ cfg.backbone.update(checkpoint=new_ckpt_path)
+ continue
+ elif isinstance(cfg.get(key), dict):
+ reset_ckpt_path(cfg.get(key), data_info)
+ else:
+ continue
+ else:
+ return
+
+if __name__ == '__main__':
+ mldb_info_tmp = {}
+ load_data_info('mldb_data_info', mldb_info_tmp)
+ print('results', mldb_info_tmp.keys())
\ No newline at end of file
diff --git a/mono/utils/pcd_filter.py b/mono/utils/pcd_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d26314d806ea961f6bf09d1fb195bf5e364f181
--- /dev/null
+++ b/mono/utils/pcd_filter.py
@@ -0,0 +1,24 @@
+import open3d as o3d
+import numpy as np
+
+def downsample_and_filter(pcd_file, max_bound_div=750, neighbor_num=8):
+ pcd = o3d.io.read_point_cloud(pcd_file)
+ point_num = len(pcd.points)
+ if (point_num > 10000000):
+ voxel_down_pcd = o3d.geometry.PointCloud.uniform_down_sample(pcd, int(point_num / 10000000)+1)
+ else:
+ voxel_down_pcd = pcd
+ max_bound = voxel_down_pcd.get_max_bound()
+ ball_radius = np.linalg.norm(max_bound) / max_bound_div
+ pcd_filter, _ = voxel_down_pcd.remove_radius_outlier(neighbor_num, ball_radius)
+ print('filtered size', len(pcd_filter.points), 'pre size:', len(pcd.points))
+ o3d.io.write_point_cloud(pcd_file[:-4] + '_filtered.ply', pcd_filter)
+
+
+if __name__ == "__main__":
+ import os
+ dir_path = './data/demo_pcd'
+ for pcd_file in os.listdir(dir_path):
+ #if 'jonathan' in pcd_file: set max_bound_div to 300 and neighbor_num to 8
+ downsample_and_filter(os.path.join(dir_path, pcd_file))
+
\ No newline at end of file
diff --git a/mono/utils/running.py b/mono/utils/running.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a8b8d2c1f355717f46f784a28ac5f327c01dfc5
--- /dev/null
+++ b/mono/utils/running.py
@@ -0,0 +1,77 @@
+import os
+import torch
+import torch.nn as nn
+from mono.utils.comm import main_process
+import copy
+import inspect
+import logging
+import glob
+
+
+def load_ckpt(load_path, model, optimizer=None, scheduler=None, strict_match=True, loss_scaler=None):
+ """
+ Load the checkpoint for resuming training or fine-tuning.
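+ The checkpoint is expected to hold 'model_state_dict' and, optionally, 'optimizer',
+ 'scheduler' and 'scaler' entries, matching what save_ckpt below writes.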
+ """
+ logger = logging.getLogger()
+ if os.path.isfile(load_path):
+ if main_process():
+ logger.info(f"Loading weight '{load_path}'")
+ checkpoint = torch.load(load_path, map_location="cpu")
+ ckpt_state_dict = checkpoint['model_state_dict']
+ model.module.load_state_dict(ckpt_state_dict, strict=strict_match)
+
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ if scheduler is not None:
+ scheduler.load_state_dict(checkpoint['scheduler'])
+ if loss_scaler is not None and 'scaler' in checkpoint:
+ loss_scaler.load_state_dict(checkpoint['scaler'])
+ del ckpt_state_dict
+ del checkpoint
+ if main_process():
+ logger.info(f"Successfully loaded weight: '{load_path}'")
+ if scheduler is not None and optimizer is not None:
+ logger.info(f"Resume training from: '{load_path}'")
+ else:
+ if main_process():
+ raise RuntimeError(f"No weight found at '{load_path}'")
+ return model, optimizer, scheduler, loss_scaler
+
+
+def save_ckpt(cfg, model, optimizer, scheduler, curr_iter=0, curr_epoch=None, loss_scaler=None):
+ """
+ Save the model, optimizer, lr scheduler.
+ """
+ logger = logging.getLogger()
+
+ if 'IterBasedRunner' in cfg.runner.type:
+ max_iters = cfg.runner.max_iters
+ elif 'EpochBasedRunner' in cfg.runner.type:
+ max_iters = cfg.runner.max_epochs
+ else:
+ raise TypeError(f'{cfg.runner.type} is not supported')
+
+ ckpt = dict(
+ model_state_dict=model.module.state_dict(),
+ optimizer=optimizer.state_dict(),
+ max_iter=max_iters,
+ scheduler=scheduler.state_dict(),
+ )
+
+ if loss_scaler is not None:
+ ckpt.update(dict(scaler=loss_scaler.state_dict()))
+
+ ckpt_dir = os.path.join(cfg.work_dir, 'ckpt')
+ os.makedirs(ckpt_dir, exist_ok=True)
+
+ save_name = os.path.join(ckpt_dir, 'step%08d.pth' %curr_iter)
+ saved_ckpts = glob.glob(ckpt_dir + '/step*.pth')
+ torch.save(ckpt, save_name)
+
+ # keep only the most recent checkpoints (drop the oldest once more than 20 exist)
+ if len(saved_ckpts) > 20:
+ saved_ckpts.sort()
+ os.remove(saved_ckpts.pop(0))
+
+ logger.info(f'Save model: {save_name}')
diff --git a/mono/utils/transform.py b/mono/utils/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..2af94efe754d6f72325db6fdc170f30fbfb8c2fe
--- /dev/null
+++ b/mono/utils/transform.py
@@ -0,0 +1,408 @@
+import collections
+import cv2
+import math
+import numpy as np
+import numbers
+import random
+import torch
+
+import matplotlib
+import matplotlib.cm
+
+
+"""
+Provides a set of Pytorch transforms that use OpenCV instead of PIL (Pytorch default)
+for image manipulation.
+"""
+
+class Compose(object):
+ # Composes transforms: transforms.Compose([transforms.RandScale([0.5, 2.0]), transforms.ToTensor()])
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
+ for t in self.transforms:
+ images, labels, intrinsics, cam_models, other_labels, transform_paras = t(images, labels, intrinsics, cam_models, other_labels, transform_paras)
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+class ToTensor(object):
+ # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
+ def __init__(self, **kwargs):
+ return
+ def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
+ if not isinstance(images, list) or not isinstance(labels, list) or not isinstance(intrinsics, list):
+ raise (RuntimeError("transform.ToTensor() only handle inputs/labels/intrinsics lists."))
+ if len(images) != len(intrinsics):
+ raise (RuntimeError("Numbers of images and intrinsics are not matched."))
+ if not isinstance(images[0], np.ndarray) or not isinstance(labels[0], np.ndarray):
+ raise (RuntimeError("transform.ToTensor() only handle np.ndarray for the input and label."
+ "[eg: data readed by cv2.imread()].\n"))
+ if not isinstance(intrinsics[0], list):
+ raise (RuntimeError("transform.ToTensor() only handle list for the camera intrinsics"))
+
+ if len(images[0].shape) > 3 or len(images[0].shape) < 2:
+ raise (RuntimeError("transform.ToTensor() only handle image(np.ndarray) with 3 dims or 2 dims.\n"))
+ if len(labels[0].shape) > 3 or len(labels[0].shape) < 2:
+ raise (RuntimeError("transform.ToTensor() only handle label(np.ndarray) with 3 dims or 2 dims.\n"))
+
+ if len(intrinsics[0]) > 4 or len(intrinsics[0]) < 3:
+ raise (RuntimeError("transform.ToTensor() only handles an intrinsic list with 3 or 4 entries.\n"))
+
+ for i, img in enumerate(images):
+ if len(img.shape) == 2:
+ img = np.expand_dims(img, axis=2)
+ images[i] = torch.from_numpy(img.transpose((2, 0, 1))).float()
+ for i, lab in enumerate(labels):
+ if len(lab.shape) == 2:
+ lab = np.expand_dims(lab, axis=0)
+ labels[i] = torch.from_numpy(lab).float()
+ for i, intrinsic in enumerate(intrinsics):
+ if len(intrinsic) == 3:
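+ # assume a 3-element intrinsic is [f, u0, v0]; duplicate the focal length to get [fx, fy, u0, v0]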
+ intrinsic = [intrinsic[0],] + intrinsic
+ intrinsics[i] = torch.tensor(intrinsic, dtype=torch.float)
+ if cam_models is not None:
+ for i, cam_model in enumerate(cam_models):
+ cam_models[i] = torch.from_numpy(cam_model.transpose((2, 0, 1))).float() if cam_model is not None else None
+ if other_labels is not None:
+ for i, lab in enumerate(other_labels):
+ if len(lab.shape) == 2:
+ lab = np.expand_dims(lab, axis=0)
+ other_labels[i] = torch.from_numpy(lab).float()
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+class Normalize(object):
+ # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std
+ def __init__(self, mean, std=None, **kwargs):
+ if std is None:
+ assert len(mean) > 0
+ else:
+ assert len(mean) == len(std)
+ self.mean = torch.tensor(mean).float()[:, None, None]
+ self.std = torch.tensor(std).float()[:, None, None] if std is not None \
+ else torch.tensor([1.0, 1.0, 1.0]).float()[:, None, None]
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
+ # if self.std is None:
+ # # for t, m in zip(image, self.mean):
+ # # t.sub(m)
+ # image = image - self.mean
+ # if ref_images is not None:
+ # for i, ref_i in enumerate(ref_images):
+ # ref_images[i] = ref_i - self.mean
+ # else:
+ # # for t, m, s in zip(image, self.mean, self.std):
+ # # t.sub(m).div(s)
+ # image = (image - self.mean) / self.std
+ # if ref_images is not None:
+ # for i, ref_i in enumerate(ref_images):
+ # ref_images[i] = (ref_i - self.mean) / self.std
+ for i, img in enumerate(images):
+ img = torch.div((img - self.mean), self.std)
+ images[i] = img
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+class LableScaleCanonical(object):
+ """
+ To resolve the observation ambiguity of the mono branch, i.e. different focal lengths (object sizes) for the same depth, cameras are
+ mapped to a canonical space. To mimic this, we set the focal length to a canonical one and scale the depth value accordingly. NOTE: resizing the image by the same ratio can also resolve the ambiguity.
+ Args:
+ images: list of RGB images.
+ labels: list of depth/disparity labels.
+ other labels: other labels, such as instance segmentations, semantic segmentations...
+ """
+ def __init__(self, **kwargs):
+ self.canonical_focal = kwargs['focal_length']
+
+ def _get_scale_ratio(self, intrinsic):
+ target_focal_x = intrinsic[0]
+ label_scale_ratio = self.canonical_focal / target_focal_x
+ pose_scale_ratio = 1.0
+ return label_scale_ratio, pose_scale_ratio
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
+ assert len(images[0].shape) == 3 and len(labels[0].shape) == 2
+ assert labels[0].dtype == np.float32
+
+ label_scale_ratio = None
+ pose_scale_ratio = None
+
+ for i in range(len(intrinsics)):
+ img_i = images[i]
+ label_i = labels[i] if i < len(labels) else None
+ intrinsic_i = intrinsics[i].copy()
+ cam_model_i = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+
+ label_scale_ratio, pose_scale_ratio = self._get_scale_ratio(intrinsic_i)
+
+ # adjust the focal length, map the current camera to the canonical space
+ intrinsics[i] = [intrinsic_i[0] * label_scale_ratio, intrinsic_i[1] * label_scale_ratio, intrinsic_i[2], intrinsic_i[3]]
+
+ # scale the label to the canonical space
+ if label_i is not None:
+ labels[i] = label_i * label_scale_ratio
+
+ if cam_model_i is not None:
+ # As the focal length is adjusted (canonical focal length), the camera model should be re-built
+ ori_h, ori_w, _ = img_i.shape
+ cam_models[i] = build_camera_model(ori_h, ori_w, intrinsics[i])
+
+
+ if transform_paras is not None:
+ transform_paras.update(label_scale_factor=label_scale_ratio, focal_scale_factor=label_scale_ratio)
+
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+class ResizeKeepRatio(object):
+ """
+ Resize and pad to a given size. Hold the aspect ratio.
+ This resizing assumes that the camera model remains unchanged.
+ Args:
+ resize_size: predefined output size.
+ """
+ def __init__(self, resize_size, padding=None, ignore_label=-1, **kwargs):
+ if isinstance(resize_size, int):
+ self.resize_h = resize_size
+ self.resize_w = resize_size
+ elif isinstance(resize_size, (list, tuple)) and len(resize_size) == 2 \
+ and isinstance(resize_size[0], int) and isinstance(resize_size[1], int) \
+ and resize_size[0] > 0 and resize_size[1] > 0:
+ self.resize_h = resize_size[0]
+ self.resize_w = resize_size[1]
+ else:
+ raise (RuntimeError("crop size error.\n"))
+ if padding is None:
+ self.padding = padding
+ elif isinstance(padding, list):
+ if all(isinstance(i, numbers.Number) for i in padding):
+ self.padding = padding
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if len(padding) != 3:
+ raise (RuntimeError("padding channel is not equal with 3\n"))
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if isinstance(ignore_label, int):
+ self.ignore_label = ignore_label
+ else:
+ raise (RuntimeError("ignore_label should be an integer number\n"))
+ # self.crop_size = kwargs['crop_size']
+ self.canonical_focal = kwargs['focal_length']
+
+ def main_data_transform(self, image, label, intrinsic, cam_model, resize_ratio, padding, to_scale_ratio):
+ """
+ Resize data first and then do the padding.
+ 'label' will be scaled.
+ """
+ h, w, _ = image.shape
+ reshape_h = int(resize_ratio * h)
+ reshape_w = int(resize_ratio * w)
+
+ pad_h, pad_w, pad_h_half, pad_w_half = padding
+
+ # resize
+ image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ # padding
+ image = cv2.copyMakeBorder(
+ image,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.padding)
+
+ if label is not None:
+ # label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ label = resize_depth_preserve(label, (reshape_h, reshape_w))
+ label = cv2.copyMakeBorder(
+ label,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+ # scale the label
+ label = label / to_scale_ratio
+
+ # Resize, adjust principal point
+ if intrinsic is not None:
+ intrinsic[0] = intrinsic[0] * resize_ratio / to_scale_ratio
+ intrinsic[1] = intrinsic[1] * resize_ratio / to_scale_ratio
+ intrinsic[2] = intrinsic[2] * resize_ratio
+ intrinsic[3] = intrinsic[3] * resize_ratio
+
+ if cam_model is not None:
+ #cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
+ cam_model = cv2.copyMakeBorder(
+ cam_model,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+
+ # Pad, adjust the principal point
+ if intrinsic is not None:
+ intrinsic[2] = intrinsic[2] + pad_w_half
+ intrinsic[3] = intrinsic[3] + pad_h_half
+ return image, label, intrinsic, cam_model
+
+ def get_label_scale_factor(self, image, intrinsic, resize_ratio):
+ ori_h, ori_w, _ = image.shape
+ # crop_h, crop_w = self.crop_size
+ ori_focal = intrinsic[0]
+
+ to_canonical_ratio = self.canonical_focal / ori_focal
+ to_scale_ratio = resize_ratio / to_canonical_ratio
+ return to_scale_ratio
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None):
+ target_h, target_w, _ = images[0].shape
+ resize_ratio_h = self.resize_h / target_h
+ resize_ratio_w = self.resize_w / target_w
+ resize_ratio = min(resize_ratio_h, resize_ratio_w)
+ reshape_h = int(resize_ratio * target_h)
+ reshape_w = int(resize_ratio * target_w)
+ pad_h = max(self.resize_h - reshape_h, 0)
+ pad_w = max(self.resize_w - reshape_w, 0)
+ pad_h_half = int(pad_h / 2)
+ pad_w_half = int(pad_w / 2)
+
+ pad_info = [pad_h, pad_w, pad_h_half, pad_w_half]
+ to_scale_ratio = self.get_label_scale_factor(images[0], intrinsics[0], resize_ratio)
+
+ for i in range(len(images)):
+ img = images[i]
+ label = labels[i] if i < len(labels) else None
+ intrinsic = intrinsics[i] if i < len(intrinsics) else None
+ cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+ img, label, intrinsic, cam_model = self.main_data_transform(
+ img, label, intrinsic, cam_model, resize_ratio, pad_info, to_scale_ratio)
+ images[i] = img
+ if label is not None:
+ labels[i] = label
+ if intrinsic is not None:
+ intrinsics[i] = intrinsic
+ if cam_model is not None:
+ cam_models[i] = cam_model
+
+ if other_labels is not None:
+
+ for i, other_lab in enumerate(other_labels):
+ # resize
+ other_lab = cv2.resize(other_lab, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ # pad
+ other_labels[i] = cv2.copyMakeBorder(
+ other_lab,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+
+ pad = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]
+ if transform_paras is not None:
+ pad_old = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ new_pad = [pad_old[0] + pad[0], pad_old[1] + pad[1], pad_old[2] + pad[2], pad_old[3] + pad[3]]
+ transform_paras.update(dict(pad=new_pad))
+ if 'label_scale_factor' in transform_paras:
+ transform_paras['label_scale_factor'] = transform_paras['label_scale_factor'] * 1.0 / to_scale_ratio
+ else:
+ transform_paras.update(label_scale_factor=1.0/to_scale_ratio)
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+class BGR2RGB(object):
+ # Converts image from BGR order to RGB order, for model initialized from Pytorch
+ def __init__(self, **kwargs):
+ return
+ def __call__(self, images, labels, intrinsics, cam_models=None,other_labels=None, transform_paras=None):
+ for i, img in enumerate(images):
+ images[i] = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return images, labels, intrinsics, cam_models, other_labels, transform_paras
+
+
+def resize_depth_preserve(depth, shape):
+ """
+ Resizes depth map preserving all valid depth pixels
+ Multiple downsampled points can be assigned to the same pixel.
+
+ Parameters
+ ----------
+ depth : np.array [h,w]
+ Depth map
+ shape : tuple (H,W)
+ Output shape
+
+ Returns
+ -------
+ depth : np.array [H,W,1]
+ Resized depth map
+ """
+ # Store dimensions and reshapes to single column
+ depth = np.squeeze(depth)
+ h, w = depth.shape
+ x = depth.reshape(-1)
+ # Create coordinate grid
+ uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2)
+ # Filters valid points
+ idx = x > 0
+ crd, val = uv[idx], x[idx]
+ # Downsamples coordinates
+ crd[:, 0] = (crd[:, 0] * (shape[0] / h) + 0.5).astype(np.int32)
+ crd[:, 1] = (crd[:, 1] * (shape[1] / w) + 0.5).astype(np.int32)
+ # Filters points inside image
+ idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1])
+ crd, val = crd[idx], val[idx]
+ # Creates downsampled depth image and assigns points
+ depth = np.zeros(shape)
+ depth[crd[:, 0], crd[:, 1]] = val
+ # Return resized depth map
+ return depth
+
+
+def build_camera_model(H : int, W : int, intrinsics : list) -> np.array:
+ """
+ Encode the camera intrinsic parameters (focal length and principal point) to a 4-channel map.
+ """
+ fx, fy, u0, v0 = intrinsics
+ f = (fx + fy) / 2.0
+ # principal point location
+ x_row = np.arange(0, W).astype(np.float32)
+ x_row_center_norm = (x_row - u0) / W
+ x_center = np.tile(x_row_center_norm, (H, 1)) # [H, W]
+
+ y_col = np.arange(0, H).astype(np.float32)
+ y_col_center_norm = (y_col - v0) / H
+ y_center = np.tile(y_col_center_norm, (W, 1)).T
+
+ # FoV
+ fov_x = np.arctan(x_center / (f / W))
+ fov_y = np.arctan(y_center / (f / H))
+
+ cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2)
+ return cam_model
+
+def gray_to_colormap(img, cmap='rainbow'):
+ """
+ Transfer gray map to matplotlib colormap
+ """
+ assert img.ndim == 2
+
+ img[img<0] = 0
+ mask_invalid = img < 1e-10
+ img = img / (img.max() + 1e-8)
+ norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1)
+ cmap_m = matplotlib.cm.get_cmap(cmap)
+ mapper = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m)
+ colormap = (mapper.to_rgba(img)[:, :, :3] * 255).astype(np.uint8)
+ colormap[mask_invalid] = 0
+ return colormap
\ No newline at end of file
diff --git a/mono/utils/unproj_pcd.py b/mono/utils/unproj_pcd.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0986d482a2ec68be1dd65719adec662272b833c
--- /dev/null
+++ b/mono/utils/unproj_pcd.py
@@ -0,0 +1,88 @@
+import numpy as np
+import torch
+from plyfile import PlyData, PlyElement
+import cv2
+
+
+def get_pcd_base(H, W, u0, v0, fx, fy):
+ x_row = np.arange(0, W)
+ x = np.tile(x_row, (H, 1))
+ x = x.astype(np.float32)
+ u_m_u0 = x - u0
+
+ y_col = np.arange(0, H) # y_col = np.arange(0, height)
+ y = np.tile(y_col, (W, 1)).T
+ y = y.astype(np.float32)
+ v_m_v0 = y - v0
+
+ x = u_m_u0 / fx
+ y = v_m_v0 / fy
+ z = np.ones_like(x)
+ pw = np.stack([x, y, z], axis=2) # [h, w, c]
+ return pw
+
+
+def reconstruct_pcd(depth, fx, fy, u0, v0, pcd_base=None, mask=None):
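+ """
+ Back-project a depth map into a point cloud of shape [H, W, 3]: each pixel (u, v) maps to
+ depth * ((u - u0) / fx, (v - v0) / fy, 1). A precomputed pcd_base can be passed to skip
+ rebuilding the unit-depth rays; an optional boolean mask zeroes out the masked points.
+ """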
+ if isinstance(depth, torch.Tensor):
+ depth = depth.cpu().numpy().squeeze()
+ depth = cv2.medianBlur(depth, 5)
+ if pcd_base is None:
+ H, W = depth.shape
+ pcd_base = get_pcd_base(H, W, u0, v0, fx, fy)
+ pcd = depth[:, :, None] * pcd_base
+ if mask is not None:
+ pcd[mask] = 0
+ return pcd
+
+
+def save_point_cloud(pcd, rgb, filename, binary=True):
+ """Save an RGB point cloud as a PLY file.
+ :paras
+ @pcd: Nx3 matrix, the XYZ coordinates
+ @rgb: Nx3 matrix, the rgb colors for each 3D point
+ """
+ assert rgb is None or pcd.shape[0] == rgb.shape[0]
+
+ if rgb is None:
+ gray_concat = np.tile(np.array([128], dtype=np.uint8),
+ (pcd.shape[0], 3))
+ points_3d = np.hstack((pcd, gray_concat))
+ else:
+ points_3d = np.hstack((pcd, rgb))
+ python_types = (float, float, float, int, int, int)
+ npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'),
+ ('green', 'u1'), ('blue', 'u1')]
+ if binary is True:
+ # Format into Numpy structured array
+ vertices = []
+ for row_idx in range(points_3d.shape[0]):
+ cur_point = points_3d[row_idx]
+ vertices.append(
+ tuple(
+ dtype(point)
+ for dtype, point in zip(python_types, cur_point)))
+ vertices_array = np.array(vertices, dtype=npy_types)
+ el = PlyElement.describe(vertices_array, 'vertex')
+
+ # write
+ PlyData([el]).write(filename)
+ else:
+ x = np.squeeze(points_3d[:, 0])
+ y = np.squeeze(points_3d[:, 1])
+ z = np.squeeze(points_3d[:, 2])
+ r = np.squeeze(points_3d[:, 3])
+ g = np.squeeze(points_3d[:, 4])
+ b = np.squeeze(points_3d[:, 5])
+
+ ply_head = 'ply\n' \
+ 'format ascii 1.0\n' \
+ 'element vertex %d\n' \
+ 'property float x\n' \
+ 'property float y\n' \
+ 'property float z\n' \
+ 'property uchar red\n' \
+ 'property uchar green\n' \
+ 'property uchar blue\n' \
+ 'end_header' % r.shape[0]
+ # ---- Save ply data to disk
+ np.savetxt(filename, np.column_stack((x, y, z, r, g, b)), fmt='%f %f %f %d %d %d', header=ply_head, comments='')
\ No newline at end of file
diff --git a/mono/utils/visualization.py b/mono/utils/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..07275030c48aeea062c0041b11ba60d911c14a3f
--- /dev/null
+++ b/mono/utils/visualization.py
@@ -0,0 +1,140 @@
+import matplotlib.pyplot as plt
+import os, cv2
+import numpy as np
+from mono.utils.transform import gray_to_colormap
+import shutil
+import glob
+from mono.utils.running import main_process
+import torch
+from html4vision import Col, imagetable
+
+def save_raw_imgs(
+ pred: torch.tensor,
+ rgb: torch.tensor,
+ filename: str,
+ save_dir: str,
+ scale: float=200.0,
+ target: torch.tensor=None,
+ ):
+ """
+ Save the raw RGB, prediction, and (optionally) GT depth as separate image files.
+ """
+ cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_rgb.jpg'), rgb)
+ cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_d.png'), (pred*scale).astype(np.uint16))
+ if target is not None:
+ cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_gt.png'), (target*scale).astype(np.uint16))
+
+
+def save_val_imgs(
+ iter: int,
+ pred: torch.tensor,
+ target: torch.tensor,
+ rgb: torch.tensor,
+ filename: str,
+ save_dir: str,
+ tb_logger=None
+ ):
+ """
+ Save GT, predictions, RGB in the same file.
+ """
+ rgb, pred_scale, target_scale, pred_color, target_color = get_data_for_log(pred, target, rgb)
+ rgb = rgb.transpose((1, 2, 0))
+ cat_img = np.concatenate([rgb, pred_color, target_color], axis=0)
+ plt.imsave(os.path.join(save_dir, filename[:-4]+'_merge.jpg'), cat_img)
+
+ # save to tensorboard
+ if tb_logger is not None:
+ tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter)
+
+def save_normal_val_imgs(
+ iter: int,
+ pred: torch.tensor,
+ targ: torch.tensor,
+ rgb: torch.tensor,
+ filename: str,
+ save_dir: str,
+ tb_logger=None,
+ mask=None,
+ ):
+ """
+ Save GT, predictions, RGB in the same file.
+ """
+ mean = np.array([123.675, 116.28, 103.53])[np.newaxis, np.newaxis, :]
+ std= np.array([58.395, 57.12, 57.375])[np.newaxis, np.newaxis, :]
+ pred = pred.squeeze()
+ targ = targ.squeeze()
+ rgb = rgb.squeeze()
+
+ if pred.size(0) == 3:
+ pred = pred.permute(1,2,0)
+ if targ.size(0) == 3:
+ targ = targ.permute(1,2,0)
+ if rgb.size(0) == 3:
+ rgb = rgb.permute(1,2,0)
+
+ pred_color = vis_surface_normal(pred, mask)
+ targ_color = vis_surface_normal(targ, mask)
+ rgb_color = ((rgb.cpu().numpy() * std) + mean).astype(np.uint8)
+
+ try:
+ cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)
+ except Exception:
+ pred_color = cv2.resize(pred_color, (rgb.shape[1], rgb.shape[0]))
+ targ_color = cv2.resize(targ_color, (rgb.shape[1], rgb.shape[0]))
+ cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)
+
+ plt.imsave(os.path.join(save_dir, filename[:-4]+'_merge.jpg'), cat_img)
+ # cv2.imwrite(os.path.join(save_dir, filename[:-4]+'.jpg'), pred_color)
+ # save to tensorboard
+ if tb_logger is not None:
+ tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter)
+
+def get_data_for_log(pred: torch.tensor, target: torch.tensor, rgb: torch.tensor):
+ mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis]
+ std= np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis]
+
+ pred = pred.squeeze().cpu().numpy()
+ target = target.squeeze().cpu().numpy()
+ rgb = rgb.squeeze().cpu().numpy()
+
+ pred[pred<0] = 0
+ target[target<0] = 0
+ max_scale = max(pred.max(), target.max())
+ pred_scale = (pred/max_scale * 10000).astype(np.uint16)
+ target_scale = (target/max_scale * 10000).astype(np.uint16)
+ pred_color = gray_to_colormap(pred)
+ target_color = gray_to_colormap(target)
+ pred_color = cv2.resize(pred_color, (rgb.shape[2], rgb.shape[1]))
+ target_color = cv2.resize(target_color, (rgb.shape[2], rgb.shape[1]))
+
+ rgb = ((rgb * std) + mean).astype(np.uint8)
+ return rgb, pred_scale, target_scale, pred_color, target_color
+
+
+def create_html(name2path, save_path='index.html', size=(256, 384)):
+ # table description
+ cols = []
+ for k, v in name2path.items():
+ col_i = Col('img', k, v) # specify image content for column
+ cols.append(col_i)
+ # html table generation
+ imagetable(cols, out_file=save_path, imsize=size)
+
+def vis_surface_normal(normal: torch.tensor, mask: torch.tensor=None) -> np.array:
+ """
+ Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255]
+ Args:
+ normal (torch.tensor, [h, w, 3]): surface normal
+ mask (torch.tensor, [h, w]): valid masks
+ """
+ normal = normal.cpu().numpy().squeeze()
+ n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True))
+ n_img_norm = normal / (n_img_L2 + 1e-8)
+ normal_vis = n_img_norm * 127
+ normal_vis += 128
+ normal_vis = normal_vis.astype(np.uint8)
+ if mask is not None:
+ mask = mask.cpu().numpy().squeeze()
+ normal_vis[~mask] = 0
+ return normal_vis
+
diff --git a/requirements_v1.txt b/requirements_v1.txt
new file mode 100644
index 0000000000000000000000000000000000000000..faf9a32bd1a19b92296b5c3f8eaab88b36ba425c
--- /dev/null
+++ b/requirements_v1.txt
@@ -0,0 +1,15 @@
+torch
+torchvision
+opencv-python
+numpy
+Pillow
+DateTime
+matplotlib
+plyfile
+HTML4Vision
+timm
+tensorboardX
+imgaug
+iopath
+imagecorruptions
+mmcv
\ No newline at end of file
diff --git a/requirements_v2.txt b/requirements_v2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7467132b4c1bf148c9cf96ea9accdfb26144bec5
--- /dev/null
+++ b/requirements_v2.txt
@@ -0,0 +1,16 @@
+torch == 2.0.1
+torchvision == 0.15.2
+opencv-python
+numpy == 1.23.1
+xformers == 0.0.21
+Pillow
+DateTime
+matplotlib
+plyfile
+HTML4Vision
+timm
+tensorboardX
+imgaug
+iopath
+imagecorruptions
+mmcv
diff --git a/test.sh b/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e3b13163089928258f9b33cc55ae45bd02fc5574
--- /dev/null
+++ b/test.sh
@@ -0,0 +1,5 @@
+python mono/tools/test_scale_cano.py \
+ 'mono/configs/HourglassDecoder/convlarge.0.3_150.py' \
+ --load-from ./weight/convlarge_hourglass_0.3_150_step750k_v1.1.pth \
+ --test_data_path ./data/wild_demo \
+ --launcher None
\ No newline at end of file
diff --git a/test_kitti.sh b/test_kitti.sh
new file mode 100644
index 0000000000000000000000000000000000000000..98c43e39aa2b308b727eb2baa195a96a1a499cf3
--- /dev/null
+++ b/test_kitti.sh
@@ -0,0 +1,5 @@
+python mono/tools/test_scale_cano.py \
+ 'mono/configs/HourglassDecoder/test_kitti_convlarge_hourglass_0.3_150.py' \
+ --load-from ./weight/convlarge_hourglass_0.3_150_step750k_v1.1.pth \
+ --test_data_path ./data/kitti_demo/test_annotations.json \
+ --launcher None
\ No newline at end of file
diff --git a/test_nyu.sh b/test_nyu.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a39f96398427f2e44c3ab227f62f9afc41d6145f
--- /dev/null
+++ b/test_nyu.sh
@@ -0,0 +1,5 @@
+python mono/tools/test_scale_cano.py \
+ 'mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py' \
+ --load-from ./weight/convlarge_hourglass_0.3_150_step750k_v1.1.pth \
+ --test_data_path ./data/nyu_demo/test_annotations.json \
+ --launcher None
\ No newline at end of file
diff --git a/test_vit.sh b/test_vit.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e75c3e93de0a50fe5c330ec9bf909097a6f08b22
--- /dev/null
+++ b/test_vit.sh
@@ -0,0 +1,5 @@
+python mono/tools/test_scale_cano.py \
+ 'mono/configs/HourglassDecoder/vit.raft5.small.py' \
+ --load-from ./weight/metric_depth_vit_small_800k.pth \
+ --test_data_path ./data/wild_demo \
+ --launcher None
\ No newline at end of file
diff --git a/training/README.md b/training/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..37c2a1e31de407704da2929152bde7e0bbbd0f66
--- /dev/null
+++ b/training/README.md
@@ -0,0 +1,19 @@
+# Training
+
+**Re-implemented training codes in public environments by @JUGGHM**
+
+This is a re-implemented and verified version of the original training code, which was developed in a private environment. The code for the overall framework, dataloaders, and losses is kept.
+However, we currently cannot provide the annotation ```json``` files due to IP issues.
+
+You can either integrate our framework into your own code (recommended), or prepare the datasets as follows (this requires significant effort).
+
+### Configure the pretrained checkpoints for ConvNeXt and DINOv2
+Download the checkpoints and configure their paths in ```data_server_info/pretrained_weight.py```
+
+### Prepare the json files
+Prepare the json files for the different datasets and configure their paths in ```data_server_info/public_datasets.py```. Some tiny examples are also provided as ```data_server_info/annos*.json```.
+
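+For reference, the bundled NYU example (```annos_test_normal_nyu_example.json```) stores one entry per sample, roughly of the form below (paths and intrinsics are dataset-specific):
+
+```json
+{"files": [
+  {"rgb": "NYU/nyu_normal/official/test/0000.png",
+   "depth": "NYU/nyu_normal/official/test/0000_d.png",
+   "cam_in": [518.8579, 519.4691, 325.58245, 253.73617],
+   "normal": "NYU/nyu_normal/official/test/0000_n.png"}
+]}
+```
+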
+### Train
+```bash mono/scripts/training_scripts/train.sh```
+
+
diff --git a/training/data_server_info/__init__.py b/training/data_server_info/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec8374be5bc1a77bc72386ebf46cb50154217684
--- /dev/null
+++ b/training/data_server_info/__init__.py
@@ -0,0 +1,2 @@
+from .public_datasets import *
+from .pretrained_weight import *
\ No newline at end of file
diff --git a/training/data_server_info/annos_test_matterport3d_example.json b/training/data_server_info/annos_test_matterport3d_example.json
new file mode 100644
index 0000000000000000000000000000000000000000..af406d511362c8d83c858580bf12633749fb00c7
--- /dev/null
+++ b/training/data_server_info/annos_test_matterport3d_example.json
@@ -0,0 +1 @@
+{"files": [{"meta_data": "Matterport3D/data/2n8kARJN3HM/2n8kARJN3HM/meta/add134cc07e64d9d8524d0d9f96c4180_i1_5.pkl"}, {"meta_data": "Matterport3D/data/SN83YJsR3w2/SN83YJsR3w2/meta/4a87c9150e8442a1b8abc51ed5073ca0_i1_4.pkl"}, {"meta_data": "Matterport3D/data/Uxmj2M2itWa/Uxmj2M2itWa/meta/0cef156ab53041da97dd6a70d3d5af0b_i1_4.pkl"}, {"meta_data": "Matterport3D/data/yqstnuAEVhm/yqstnuAEVhm/meta/e9b4d8e951cb4712b3905c8f4c4dabb5_i2_1.pkl"}, {"meta_data": "Matterport3D/data/dhjEzFoUFzH/dhjEzFoUFzH/meta/3d1a8e5759a14f2a81e5d6e2f5045eca_i2_2.pkl"}]}
\ No newline at end of file
diff --git a/training/data_server_info/annos_test_normal_nyu_example.json b/training/data_server_info/annos_test_normal_nyu_example.json
new file mode 100644
index 0000000000000000000000000000000000000000..5e71142338f5cba3667b4fefbd7fffcaa298b676
--- /dev/null
+++ b/training/data_server_info/annos_test_normal_nyu_example.json
@@ -0,0 +1 @@
+{"files": [{"rgb": "NYU/nyu_normal/official/test/0000.png", "depth": "NYU/nyu_normal/official/test/0000_d.png", "cam_in": [518.8579, 519.4691, 325.58245, 253.73617], "normal": "NYU/nyu_normal/official/test/0000_n.png"}, {"rgb": "NYU/nyu_normal/official/test/0001.png", "depth": "NYU/nyu_normal/official/test/0001_d.png", "cam_in": [518.8579, 519.4691, 325.58245, 253.73617], "normal": "NYU/nyu_normal/official/test/0001_n.png"}, {"rgb": "NYU/nyu_normal/official/test/0008.png", "depth": "NYU/nyu_normal/official/test/0008_d.png", "cam_in": [518.8579, 519.4691, 325.58245, 253.73617], "normal": "NYU/nyu_normal/official/test/0008_n.png"}, {"rgb": "NYU/nyu_normal/official/test/0013.png", "depth": "NYU/nyu_normal/official/test/0013_d.png", "cam_in": [518.8579, 519.4691, 325.58245, 253.73617], "normal": "NYU/nyu_normal/official/test/0013_n.png"}, {"rgb": "NYU/nyu_normal/official/test/0014.png", "depth": "NYU/nyu_normal/official/test/0014_d.png", "cam_in": [518.8579, 519.4691, 325.58245, 253.73617], "normal": "NYU/nyu_normal/official/test/0014_n.png"}, {"rgb": "NYU/nyu_normal/official/test/0015.png", "depth": "NYU/nyu_normal/official/test/0015_d.png", "cam_in": [518.8579, 519.4691, 325.58245, 253.73617], "normal": "NYU/nyu_normal/official/test/0015_n.png"}, {"rgb": "NYU/nyu_normal/official/test/0016.png", "depth": "NYU/nyu_normal/official/test/0016_d.png", "cam_in": [518.8579, 519.4691, 325.58245, 253.73617], "normal": "NYU/nyu_normal/official/test/0016_n.png"}, {"rgb": "NYU/nyu_normal/official/test/0017.png", "depth": "NYU/nyu_normal/official/test/0017_d.png", "cam_in": [518.8579, 519.4691, 325.58245, 253.73617], "normal": "NYU/nyu_normal/official/test/0017_n.png"}]}
\ No newline at end of file
diff --git a/training/data_server_info/pretrained_weight.py b/training/data_server_info/pretrained_weight.py
new file mode 100644
index 0000000000000000000000000000000000000000..2752bd7411cef60e23c8deedccb167803df72f37
--- /dev/null
+++ b/training/data_server_info/pretrained_weight.py
@@ -0,0 +1,21 @@
+db_info={}
+
+
+
+db_info['checkpoint']={
+ 'db_root': 'tbd_weight_root', # Config your weight root!
+
+ # pretrained weight for vit
+ 'vit_small_reg': 'vit/dinov2_vits14_reg4_pretrain.pth',
+ 'vit_large_reg': 'vit/dinov2_vitl14_reg4_pretrain.pth',
+ 'vit_giant2_reg': 'vit/dinov2_vitg14_reg4_pretrain.pth',
+
+ 'vit_large': 'vit/dinov2_vitl14_pretrain.pth',
+
+ # pretrained weight for convnext
+ 'convnext_tiny': 'convnext/convnext_tiny_22k_1k_384.pth',
+ 'convnext_small': 'convnext/convnext_small_22k_1k_384.pth',
+ 'convnext_base': 'convnext/convnext_base_22k_1k_384.pth',
+ 'convnext_large': 'convnext/convnext_large_22k_1k_384.pth',
+ 'convnext_xlarge': 'convnext/convnext_xlarge_22k_1k_384_ema.pth',
+}
\ No newline at end of file
diff --git a/training/data_server_info/public_datasets.py b/training/data_server_info/public_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e316d85883628cc48d7dcf8fda81e1b4e1202a1
--- /dev/null
+++ b/training/data_server_info/public_datasets.py
@@ -0,0 +1,416 @@
+
+db_info={}
+
+
+#### DDAD Dataset
+# RGBD, consecutive frames, and ring cameras annotations
+db_info['DDAD']={
+ 'db_root': 'tbd_data_root', # Config your data root!
+ 'data_root': 'DDAD',
+ 'semantic_root': 'DDAD',
+ 'meta_data_root': 'DDAD',
+ 'train_annotations_path': 'DDAD/DDAD/annotations/train.json',
+ 'test_annotations_path': 'DDAD/DDAD/annotations/test.json',
+ 'val_annotations_path': 'DDAD/DDAD/annotations/val.json',
+}
+
+#### Mapillary Planet Scale Dataset
+# Single frame RGBD annotations
+db_info['Mapillary_PSD']={
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Mapillary_PSD',
+ 'semantic_root': 'Mapillary_PSD',
+ 'train_annotations_path': 'Mapillary_PSD/Mapillary_PSD/annotations/train.json',
+ 'test_annotations_path': 'Mapillary_PSD/Mapillary_PSD/annotations/test.json',
+ 'val_annotations_path': 'Mapillary_PSD/Mapillary_PSD/annotations/val.json',
+}
+
+#### Cityscapes dataset
+# Cityscapes sequence dataset, RGBD and consecutive frames annotations
+db_info['Cityscapes_sequence'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Cityscapes_sequence',
+ 'semantic_root': 'Cityscapes_sequence',
+ 'train_annotations_path': 'Cityscapes_sequence/Cityscapes_sequence/annotations/train.json',
+ 'test_annotations_path': 'Cityscapes_sequence/Cityscapes_sequence/annotations/test.json',
+ 'val_annotations_path': 'Cityscapes_sequence/Cityscapes_sequence/annotations/val.json',
+}
+# Cityscapes extra dataset, RGBD annotations
+db_info['Cityscapes_trainextra'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Cityscapes_trainextra',
+ 'train_annotations_path': 'Cityscapes_trainextra/Cityscapes_trainextra/annotations/train.json',
+ 'test_annotations_path': 'Cityscapes_trainextra/Cityscapes_trainextra/annotations/test.json',
+ 'val_annotations_path': 'Cityscapes_trainextra/Cityscapes_trainextra/annotations/val.json',
+}
+db_info['Cityscapes_sequence_test'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Cityscapes_sequence',
+ 'train_annotations_path': 'Cityscapes_sequence/Cityscapes_sequence/annotations/train.json',
+ 'test_annotations_path': 'Cityscapes_sequence/Cityscapes_sequence/annotations/test.json',
+ 'val_annotations_path': 'Cityscapes_sequence/Cityscapes_sequence/annotations/test.json',
+}
+
+#### Lyft dataset
+# Lyft dataset, RGBD, neighbouring cameras, and consecutive frames annotations
+db_info['Lyft'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Lyft',
+ 'depth_root': 'Lyft',
+ 'meta_data_root': 'Lyft',
+ 'semantic_root': 'Lyft',
+ 'train_annotations_path': 'Lyft/Lyft/annotations/train.json',
+ 'test_annotations_path': 'Lyft/Lyft/annotations/test.json',
+ 'val_annotations_path': 'Lyft/Lyft/annotations/val.json',
+}
+# Lyft dataset, RGBD for ring cameras
+db_info['Lyft_ring'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Lyft',
+ 'depth_root': 'Lyft',
+ 'meta_data_root': 'Lyft',
+ 'train_annotations_path': 'Lyft/Lyft/annotations/train.json',
+ 'test_annotations_path': 'Lyft/Lyft/annotations/test.json',
+ 'val_annotations_path': 'Lyft/Lyft/annotations/val.json',
+}
+
+#### DSEC dataset
+# DSEC dataset, RGBD and consecutive frames annotations
+db_info['DSEC'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'DSEC',
+ 'semantic_root': 'DSEC',
+ 'train_annotations_path': 'DSEC/DSEC/annotations/train.json',
+ 'test_annotations_path': 'DSEC/DSEC/annotations/test.json',
+ 'val_annotations_path': 'DSEC/DSEC/annotations/val.json',
+}
+
+#### Argovers2 Dataset
+# Argovers2 dataset, RGBD and neighbouring cameras annotations
+db_info['Argovers2'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Argovers2',
+ 'depth_root': 'Argovers2',
+ 'meta_data_root': 'Argovers2',
+ 'train_annotations_path': 'Argovers2/Argovers2/annotations/train.json',
+ 'test_annotations_path': 'Argovers2/Argovers2/annotations/test.json',
+ 'val_annotations_path': 'Argovers2/Argovers2/annotations/val.json',
+}
+# Argovers2 dataset, RGBD and consecutive frames annotations
+db_info['Argovers2_tmpl'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Argovers2',
+ 'depth_root': 'Argovers2',
+ 'meta_data_root': 'Argovers2',
+ 'train_annotations_path': 'Argovers2/Argovers2/annotations/train.json',
+ 'test_annotations_path': 'Argovers2/Argovers2/annotations/test.json',
+ 'val_annotations_path': 'Argovers2/Argovers2/annotations/val.json',
+}
+
+#### DrivingStereo Dataset
+# DrivingStereo dataset, RGBD annotations for stereo data
+db_info['DrivingStereo'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'DrivingStereo',
+ 'semantic_root': 'DrivingStereo',
+ 'train_annotations_path': 'DrivingStereo/DrivingStereo/annotations/train.json',
+ 'test_annotations_path': 'DrivingStereo/DrivingStereo/annotations/test.json',
+ 'val_annotations_path': 'DrivingStereo/DrivingStereo/annotations/val.json',
+}
+# DrivingStereo dataset, RGBD and consecutive frames annotations for stereo data
+db_info['DrivingStereo_tmpl'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'DrivingStereo',
+ 'semantic_root': 'DrivingStereo',
+ 'train_annotations_path': 'DrivingStereo/DrivingStereo/annotations/train.json',
+ 'test_annotations_path': 'DrivingStereo/DrivingStereo/annotations/test.json',
+ 'val_annotations_path': 'DrivingStereo/DrivingStereo/annotations/val.json',
+}
+
+#### DIML Dataset
+# DIML dataset, RGBD annotations for stereo data
+db_info['DIML'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'DIML',
+ 'semantic_root': 'DIML',
+ 'train_annotations_path': 'DIML/DIML/anotation/train.json',
+ 'test_annotations_path': 'DIML/DIML/anotation/test.json',
+ 'val_annotations_path': 'DIML/DIML/anotation/val.json',
+}
+
+db_info['NuScenes'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'NuScenes',
+ 'train_annotations_path': 'NuScenes/NuScenes/annotations/train.json',
+ 'test_annotations_path': 'NuScenes/NuScenes/annotations/test.json',
+ 'val_annotations_path': 'NuScenes/NuScenes/annotations/val.json',
+}
+db_info['NuScenes_tmpl'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'NuScenes',
+ 'train_annotations_path': 'NuScenes/NuScenes/annotations/train.json',
+ 'test_annotations_path': 'NuScenes/NuScenes/annotations/test.json',
+ 'val_annotations_path': 'NuScenes/NuScenes/annotations/val.json',
+}
+
+
+# Pandaset, RGBD + tmpl dataset
+db_info['Pandaset'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Pandaset',
+ 'meta_data_root': 'Pandaset',
+ 'semantic_root': 'Pandaset',
+ 'train_annotations_path': 'Pandaset/Pandaset/annotations/train.json',
+ 'test_annotations_path': 'Pandaset/Pandaset/annotations/test.json',
+ 'val_annotations_path': 'Pandaset/Pandaset/annotations/val.json',
+}
+db_info['Pandaset_ring'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Pandaset',
+ 'meta_data_root': 'Pandaset',
+ 'semantic_root': 'Pandaset',
+ 'train_annotations_path': 'Pandaset/Pandaset/annotations/train.json',
+ 'test_annotations_path': 'Pandaset/Pandaset/annotations/test.json',
+ 'val_annotations_path': 'Pandaset/Pandaset/annotations/val.json',
+}
+
+# UASOL, RGBD + tmpl dataset
+db_info['UASOL'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'UASOL_data',
+ 'meta_data_root': 'UASOL_data',
+ 'semantic_root': 'UASOL_data',
+ 'train_annotations_path': 'UASOL_data/UASOL_data/annotations/train.json',
+ 'test_annotations_path': 'UASOL_data/UASOL_data/annotations/test.json',
+ 'val_annotations_path': 'UASOL_data/UASOL_data/annotations/test.json',
+}
+
+# Taskonomy, RGBD dataset
+db_info['Taskonomy'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Taskonomy',
+ 'meta_data_root': 'Taskonomy',
+ 'semantic_root': 'Taskonomy',
+ 'normal_root': 'Taskonomy',
+
+ 'train_annotations_path': 'Taskonomy/Taskonomy/annotations/train.json',
+ 'test_annotations_path': 'Taskonomy/Taskonomy/annotations/test.json',
+ 'val_annotations_path': 'Taskonomy/Taskonomy/annotations/test.json',
+}
+
+### WebStereo Datasets
+# HRWSI/Holopix dataset, RGBD and sky masks annotations
+db_info['HRWSI_Holopix'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'WebStereo',
+ 'train_annotations_path': 'WebStereo/annotations/train.json',
+ 'test_annotations_path': 'WebStereo/annotations/test.json',
+ 'val_annotations_path': 'WebStereo/annotations/val.json',
+}
+
+### Waymo Datasets
+db_info['Waymo'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Waymo',
+ 'meta_data_root': 'Waymo',
+ 'semantic_root': 'Waymo',
+ 'train_annotations_path': 'Waymo/Waymo/annotations/training_annos_all_filter.json',
+ 'test_annotations_path': 'Waymo/Waymo/annotations/testing_annos_all_filter.json',
+ 'val_annotations_path': 'Waymo/Waymo/annotations/validation_annos_all_filter.json',
+}
+
+
+# DIODE, RGBD dataset
+db_info['DIODE'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'DIODE',
+ 'depth_mask_root': 'DIODE',
+ 'normal_root': 'DIODE',
+ 'train_annotations_path': 'DIODE/DIODE/annotations/train.json',
+ 'test_annotations_path': 'DIODE/DIODE/annotations/test.json',
+ 'val_annotations_path': 'DIODE/DIODE/annotations/val.json',
+}
+db_info['DIODE_indoor'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'DIODE',
+ 'depth_mask_root': 'DIODE',
+ 'train_annotations_path': 'DIODE/DIODE/annotations/train.json',
+ 'test_annotations_path': 'DIODE/DIODE/annotations/test.json',
+ 'val_annotations_path': 'DIODE/DIODE/annotations/val.json',
+}
+db_info['DIODE_outdoor'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'DIODE',
+ 'depth_mask_root': 'DIODE',
+ 'normal_root': 'DIODE',
+ 'train_annotations_path': 'DIODE/DIODE/annotations/train.json',
+ 'test_annotations_path': 'DIODE/DIODE/annotations/test.json',
+ 'val_annotations_path': 'DIODE/DIODE/annotations/val.json',
+}
+db_info['ETH3D'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'ETH3D',
+ 'depth_mask_root': 'ETH3D',
+ 'train_annotations_path': 'ETH3D/ETH3D/annotations/test.json',
+ 'test_annotations_path': 'ETH3D/ETH3D/annotations/test.json',
+ 'val_annotations_path': 'ETH3D/ETH3D/annotations/test.json',
+}
+# NYU, RGBD dataset
+db_info['NYU'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'NYU',
+ 'normal_root': 'NYU',
+ #'train_annotations_path': 'NYU/NYU/annotations/train.json',
+ 'train_annotations_path': 'NYU/NYU/annotations/train_normal.json',
+ #'test_annotations_path': 'NYU/NYU/annotations/test.json',
+ 'test_annotations_path': 'NYU/NYU/annotations/test_normal.json',
+ 'val_annotations_path': 'NYU/NYU/annotations/test.json',
+}
+# ScanNet, RGBD dataset
+db_info['ScanNet'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'ScanNet',
+ 'train_annotations_path': 'ScanNet/ScanNet/annotations/train.json',
+ 'test_annotations_path': 'ScanNet/ScanNet/annotations/test.json',
+ 'val_annotations_path': 'ScanNet/ScanNet/annotations/test.json',
+}
+# KITTI, RGBD dataset
+db_info['KITTI'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': '',
+ 'train_annotations_path': 'KITTI/KITTI/annotations/eigen_train.json',
+ 'test_annotations_path': 'KITTI/KITTI/annotations/eigen_test.json',
+ 'val_annotations_path': 'KITTI/KITTI/annotations/eigen_test.json',
+}
+
+
+########### new training data
+# Blended_mvg, RGBD dataset
+db_info['BlendedMVG_omni'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Blended_mvg',
+ 'meta_data_root': 'Blended_mvg',
+ 'normal_root': 'Blended_mvg',
+ 'train_annotations_path': 'Blended_mvg/Blended_mvg/annotations/train.json',
+ 'test_annotations_path': 'Blended_mvg/Blended_mvg/annotations/test.json',
+ 'val_annotations_path': 'Blended_mvg/Blended_mvg/annotations/val.json',
+}
+
+# HM3D, RGBD dataset
+db_info['HM3D'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'HM3D',
+ 'meta_data_root': 'HM3D',
+ 'normal_root': 'HM3D',
+    'train_annotations_path': 'HM3D/HM3d_omnidata/annotations/train.json',
+ 'test_annotations_path': 'HM3D/HM3d_omnidata/annotations/val.json',
+ 'val_annotations_path': 'HM3D/HM3d_omnidata/annotations/test.json',
+}
+
+# LeddarPixSet, RGBD dataset, some errors in the data
+db_info['LeddarPixSet'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'LeddarPixSet',
+ 'meta_data_root': 'LeddarPixSet',
+ 'train_annotations_path': 'LeddarPixSet/LeddarPixSet/annotations/train.json',
+ 'test_annotations_path': 'LeddarPixSet/LeddarPixSet/annotations/test.json',
+ 'val_annotations_path': 'LeddarPixSet/LeddarPixSet/annotations/val.json',
+}
+
+# RGBD dataset
+db_info['Replica'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Replica',
+ 'meta_data_root': 'Replica',
+ 'normal_root': 'Replica',
+ 'train_annotations_path': 'Replica/replica/annotations/train.json',
+ 'test_annotations_path': 'Replica/replica/annotations/test.json',
+ 'val_annotations_path': 'Replica/replica/annotations/val.json',
+}
+
+db_info['Replica_gso'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Replica',
+ 'meta_data_root': 'Replica',
+ 'normal_root': 'Replica',
+ 'train_annotations_path': 'Replica/replica_gso/annotations/train.json',
+ 'test_annotations_path': 'Replica/replica_gso/annotations/test.json',
+ 'val_annotations_path': 'Replica/replica_gso/annotations/val.json',
+}
+
+db_info['Matterport3D'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'Matterport3D',
+ 'meta_data_root': 'Matterport3D',
+ 'normal_root': 'Matterport3D',
+ 'train_annotations_path': 'Matterport3D/Matterport3D/annotations/train.json',
+ 'test_annotations_path': 'Matterport3D/Matterport3D/annotations/test.json',
+ 'val_annotations_path': 'Matterport3D/Matterport3D/annotations/test.json',
+}
+
+db_info['S3DIS'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 's3dis',
+ 'meta_data_root': 's3dis',
+ 'normal_root': 's3dis',
+ 'train_annotations_path': 's3dis/s3dis/annotations/train.json',
+ 'test_annotations_path': 's3dis/s3dis/annotations/test.json',
+ 'val_annotations_path': 's3dis/s3dis/annotations/test.json',
+}
+
+db_info['Seasons4'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': '4seasons/4seasons',
+ 'meta_data_root': '4seasons/4seasons',
+ 'train_annotations_path': '4seasons/4seasons/annotations/train.json',
+ 'test_annotations_path': '4seasons/4seasons/annotations/test.json',
+ 'val_annotations_path': '4seasons/4seasons/annotations/test.json',
+}
+
+db_info['Virtual_KITTI'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'virtual_kitti',
+ 'meta_data_root': 'virtual_kitti',
+ 'semantic_root': 'virtual_kitti',
+ 'train_annotations_path': 'virtual_kitti/virtual_kitti/annotations/train.json',
+ 'test_annotations_path': 'virtual_kitti/virtual_kitti/annotations/test.json',
+ 'val_annotations_path': 'virtual_kitti/virtual_kitti/annotations/test.json',
+}
+
+db_info['IBIMS'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': '',
+ 'train_annotations_path': 'iBims-1/annotations/train.json',
+ 'test_annotations_path': 'iBims-1/annotations/test.json',
+ 'val_annotations_path': 'iBims-1/annotations/test.json',
+}
+
+db_info['ScanNetAll'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': 'scannet',
+ 'normal_root': 'scannet',
+ 'meta_data_root': 'scannet',
+ 'train_annotations_path': 'scannet/scannet/annotations/train.json',
+ 'test_annotations_path': 'scannet/scannet/annotations/test.json',
+ 'val_annotations_path': 'scannet/scannet/annotations/test.json',
+}
+
+db_info['Hypersim'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': '',
+ 'meta_data_root': '',
+ 'normal_root': '',
+ # 'semantic_root': '', # Semantic tags without sky, see https://github.com/apple/ml-hypersim/blob/main/code/cpp/tools/scene_annotation_tool/semantic_label_descs.csv
+ 'train_annotations_path': 'Hypersim/annotations/train.json',
+ 'test_annotations_path': 'Hypersim/annotations/test.json',
+ 'val_annotations_path': 'Hypersim/annotations/test.json',
+}
+
+db_info['DIML_indoor'] = {
+ 'db_root': 'tbd_data_root',
+ 'data_root': '',
+ # 'semantic_root': '',
+ 'train_annotations_path': 'DIML_indoor_new/annotations/train.json',
+ 'test_annotations_path': 'DIML_indoor_new/annotations/test.json',
+ 'val_annotations_path': 'DIML_indoor_new/annotations/test.json',
+}
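+
+# Usage sketch (illustrative assumption; the actual dataloaders live under
+# training/mono and may differ): each entry stores paths relative to
+# 'db_root', and the annotation JSONs follow the
+# {"files": [{"rgb": ..., "depth": ..., "cam_in": [fx, fy, cx, cy], ...}]}
+# layout used by the test annotations shipped in this repo.
+def load_annotation_list(dataset_name, split='train', info=db_info):
+    """Read the annotation list of one dataset split."""
+    import json
+    import os.path as osp
+    cfg = info[dataset_name]
+    anno_path = osp.join(cfg['db_root'], cfg['%s_annotations_path' % split])
+    with open(anno_path, 'r') as f:
+        return json.load(f)['files']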
\ No newline at end of file
diff --git a/training/mono/__init__.py b/training/mono/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/training/mono/configs/RAFTDecoder/vit.raft5.giant2.kitti.py b/training/mono/configs/RAFTDecoder/vit.raft5.giant2.kitti.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b6a7202ddeb34e49ca742282bab357bba5dc26d
--- /dev/null
+++ b/training/mono/configs/RAFTDecoder/vit.raft5.giant2.kitti.py
@@ -0,0 +1,132 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/nyu.py',
+ '../_base_/datasets/kitti.py'
+ ]
+
+import numpy as np
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ ),
+)
+
+# loss method
+losses=dict(
+ decoder_losses=[
+ dict(type='VNLoss', sample_ratio=0.2, loss_weight=0.1),
+ dict(type='GRUSequenceLoss', loss_weight=1.0, loss_gamma=0.9, stereo_sup=0),
+ dict(type='DeNoConsistencyLoss', loss_weight=0.001, loss_fn='CEL', scale=2)
+ ],
+)
+
+data_array = [
+
+ [
+ dict(KITTI='KITTI_dataset'),
+ ],
+]
+
+
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),
+# crop_size=(544, 1216),
+# crop_size = (544, 992),
+    crop_size = (616, 1064), # both dimensions are multiples of 28
+)
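+
+# Rough sketch of the canonical-space idea encoded above (an assumption about
+# what the 'LabelScaleCononical' transform does): depth labels are rescaled as
+# if the image had been captured with the canonical focal length of 1000 px:
+#   scale = data_basic['canonical_space']['focal_length'] / fx
+#   depth_canonical = depth * scale
+# e.g. fx = 518.86 (the fx listed in this repo's NYU test annotations) gives a
+# scale of roughly 1.93.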
+
+# online evaluation
+# evaluation = dict(online_eval=True, interval=1000, metrics=['abs_rel', 'delta1', 'rmse'], multi_dataset_eval=True)
+#log_interval = 100
+
+interval = 4000
+log_interval = 100
+evaluation = dict(
+ online_eval=False,
+ interval=interval,
+ metrics=['abs_rel', 'delta1', 'rmse', 'normal_mean', 'normal_rmse', 'normal_a1'],
+ multi_dataset_eval=True,
+ exclude=['DIML_indoor', 'GL3D', 'Tourism', 'MegaDepth'],
+)
+
+# save checkpoints during training; a runner type with the '*_AMP' suffix uses automatic mixed-precision training
+checkpoint_config = dict(by_epoch=False, interval=interval)
+runner = dict(type='IterBasedRunner_AMP', max_iters=20010)
+
+# optimizer
+optimizer = dict(
+ type='AdamW',
+ encoder=dict(lr=5e-7, betas=(0.9, 0.999), weight_decay=0, eps=1e-10),
+ decoder=dict(lr=1e-5, betas=(0.9, 0.999), weight_decay=0, eps=1e-10),
+ strict_match = True
+)
+# schedule
+lr_config = dict(policy='poly',
+ warmup='linear',
+ warmup_iters=20,
+ warmup_ratio=1e-6,
+ power=0.9, min_lr=1e-8, by_epoch=False)
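+
+# For reference (assuming the usual mmcv-style scheduler backing this config):
+# after the linear warmup of `warmup_iters` steps, the 'poly' policy decays as
+#   lr(i) = (base_lr - min_lr) * (1 - i / max_iters) ** power + min_lr
+# so the encoder/decoder learning rates above fall smoothly towards min_lr.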
+
+acc_batch = 1
+batchsize_per_gpu = 2
+thread_per_gpu = 2
+
+KITTI_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                        crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                        crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
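+
+# Note on the Normalize step used in the pipelines above: the mean/std values
+# are ImageNet RGB statistics expressed on the 0-255 scale, i.e. the transform
+# is assumed to compute (img - mean) / std per channel before the tensors are
+# fed to the ViT backbone.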
diff --git a/training/mono/configs/RAFTDecoder/vit.raft5.giant2.nyu.py b/training/mono/configs/RAFTDecoder/vit.raft5.giant2.nyu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c59676811aec1a05917adeca2c1f43a46e9bec88
--- /dev/null
+++ b/training/mono/configs/RAFTDecoder/vit.raft5.giant2.nyu.py
@@ -0,0 +1,136 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/nyu.py',
+ '../_base_/datasets/kitti.py'
+ ]
+
+import numpy as np
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ ),
+)
+
+# loss method
+losses=dict(
+ decoder_losses=[
+ dict(type='VNLoss', sample_ratio=0.2, loss_weight=1.0),
+ dict(type='GRUSequenceLoss', loss_weight=1.0, loss_gamma=0.9, stereo_sup=0),
+ dict(type='NormalBranchLoss', loss_weight=1.5, loss_fn='NLL_ours_GRU'),
+ dict(type='DeNoConsistencyLoss', loss_weight=0.001, loss_fn='CEL', scale=2),
+ dict(type='HDNRandomLoss', loss_weight=0.5, random_num=10),
+ dict(type='HDSNRandomLoss', loss_weight=0.5, random_num=20, batch_limit=4),
+ dict(type='PWNPlanesLoss', loss_weight=1),
+ ],
+)
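+
+# The decoder losses above are assumed to be combined as a plain weighted sum,
+#   total = sum(loss_weight_i * L_i),
+# so the depth supervision terms and the normal branch (weight 1.5) dominate,
+# while DeNoConsistencyLoss (0.001) only lightly couples the depth and normal
+# predictions.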
+
+data_array = [
+
+ [
+ dict(NYU='NYU_dataset'),
+ ],
+]
+
+
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),
+# crop_size=(544, 1216),
+# crop_size = (544, 992),
+    crop_size = (616, 1064), # both dimensions are multiples of 28
+)
+
+# online evaluation
+# evaluation = dict(online_eval=True, interval=1000, metrics=['abs_rel', 'delta1', 'rmse'], multi_dataset_eval=True)
+#log_interval = 100
+
+interval = 4000
+log_interval = 200
+evaluation = dict(
+ online_eval=False,
+ interval=interval,
+ metrics=['abs_rel', 'delta1', 'rmse', 'normal_mean', 'normal_rmse', 'normal_a1'],
+ multi_dataset_eval=True,
+ exclude=['DIML_indoor', 'GL3D', 'Tourism', 'MegaDepth'],
+)
+
+# save checkpoints during training; a runner type with the '*_AMP' suffix uses automatic mixed-precision training
+checkpoint_config = dict(by_epoch=False, interval=interval)
+runner = dict(type='IterBasedRunner_AMP', max_iters=20010)
+
+# optimizer
+optimizer = dict(
+ type='AdamW',
+ encoder=dict(lr=5e-7, betas=(0.9, 0.999), weight_decay=0, eps=1e-10),
+ decoder=dict(lr=1e-5, betas=(0.9, 0.999), weight_decay=0, eps=1e-10),
+ strict_match = True
+)
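+
+# A minimal sketch of how the encoder/decoder entries above could map onto two
+# torch.optim parameter groups (attribute names are assumptions for
+# illustration; the real optimizer builder may differ):
+#   import torch
+#   optim = torch.optim.AdamW(
+#       [{'params': model.encoder.parameters(), 'lr': 5e-7},
+#        {'params': model.decode_head.parameters(), 'lr': 1e-5}],
+#       betas=(0.9, 0.999), weight_decay=0, eps=1e-10)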
+# schedule
+lr_config = dict(policy='poly',
+ warmup='linear',
+ warmup_iters=20,
+ warmup_ratio=1e-6,
+ power=0.9, min_lr=1e-8, by_epoch=False)
+
+acc_batch = 1
+batchsize_per_gpu = 2
+thread_per_gpu = 2
+
+NYU_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                        crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                        crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
diff --git a/training/mono/configs/RAFTDecoder/vit.raft5.giant2.py b/training/mono/configs/RAFTDecoder/vit.raft5.giant2.py
new file mode 100644
index 0000000000000000000000000000000000000000..51cd0839c63c475a6cd9bf365b9b02229d67156b
--- /dev/null
+++ b/training/mono/configs/RAFTDecoder/vit.raft5.giant2.py
@@ -0,0 +1,1048 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/ddad.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/datasets/argovers2.py',
+ '../_base_/datasets/cityscapes.py',
+ '../_base_/datasets/drivingstereo.py',
+ '../_base_/datasets/dsec.py',
+ '../_base_/datasets/lyft.py',
+ '../_base_/datasets/mapillary_psd.py',
+ '../_base_/datasets/diml.py',
+ '../_base_/datasets/taskonomy.py',
+ '../_base_/datasets/uasol.py',
+ '../_base_/datasets/pandaset.py',
+ '../_base_/datasets/waymo.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py',
+
+ '../_base_/datasets/hm3d.py',
+ '../_base_/datasets/matterport3d.py',
+ '../_base_/datasets/replica.py',
+ '../_base_/datasets/vkitti.py',
+ ]
+
+import numpy as np
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ ),
+)
+
+# loss method
+losses=dict(
+ decoder_losses=[
+ dict(type='VNLoss', sample_ratio=0.2, loss_weight=1.0),
+ dict(type='GRUSequenceLoss', loss_weight=0.5, loss_gamma=0.9, stereo_sup=0.0),
+ dict(type='SkyRegularizationLoss', loss_weight=0.001, sample_ratio=0.4, regress_value=200, normal_regress=[0, 0, -1]),
+ dict(type='HDNRandomLoss', loss_weight=0.5, random_num=10),
+ dict(type='HDSNRandomLoss', loss_weight=0.5, random_num=20, batch_limit=4),
+ dict(type='PWNPlanesLoss', loss_weight=1),
+ dict(type='NormalBranchLoss', loss_weight=1.5, loss_fn='NLL_ours_GRU'),
+ dict(type='DeNoConsistencyLoss', loss_weight=0.01, loss_fn='CEL', scale=2, depth_detach=True)
+ ],
+ gru_losses=[
+ dict(type='SkyRegularizationLoss', loss_weight=0.001, sample_ratio=0.4, regress_value=200, normal_regress=[0, 0, -1]),
+ ],
+)
+
+data_array = [
+ # Outdoor 1
+ [
+        dict(UASOL='UASOL_dataset'), # ~136k images
+        dict(Cityscapes_trainextra='Cityscapes_dataset'), # ~18k images
+        dict(Cityscapes_sequence='Cityscapes_dataset'), # ~135k images
+        dict(DIML='DIML_dataset'), # ~122k images
+        dict(Waymo='Waymo_dataset'), # ~990k images
+ ],
+ # Outdoor 2
+ [
+ dict(DSEC='DSEC_dataset'),
+        dict(Mapillary_PSD='MapillaryPSD_dataset'), # ~742k images
+        dict(DrivingStereo='DrivingStereo_dataset'), # ~176k images
+        dict(Argovers2='Argovers2_dataset'), # ~2.86M images
+ ],
+ # Outdoor 3
+ [
+        dict(Lyft='Lyft_dataset'), # ~158k images
+        dict(DDAD='DDAD_dataset'), # ~74k images
+        dict(Pandaset='Pandaset_dataset'), # ~38k images
+        dict(Virtual_KITTI='VKITTI_dataset'), # ~37k images, synthetic
+ ],
+    # Indoor 1
+    [
+        dict(Replica='Replica_dataset'), # ~56k images, synthetic
+        dict(Replica_gso='Replica_dataset'), # ~107k images, synthetic
+        dict(Hypersim='Hypersim_dataset'), # ~24k images
+ dict(ScanNetAll='ScanNetAll_dataset'),
+ ],
+ # Indoor 2
+ [
+        dict(Taskonomy='Taskonomy_dataset'), # ~4.47M images
+        dict(Matterport3D='Matterport3D_dataset'), # ~144k images
+        dict(HM3D='HM3D_dataset'), # ~2M images, very noisy; only a subset is sampled
+ ],
+]
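+
+# Reading of the structure above (an assumption about the sampling logic,
+# which lives in the dataloader code): each inner list is a dataset group
+# (Outdoor 1-3, Indoor 1-2), and batches are drawn so the groups stay roughly
+# balanced regardless of their very different sizes, e.g.
+#   import random
+#   group = random.choice(data_array)                 # pick a group
+#   name, cfg_key = next(iter(random.choice(group).items()))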
+
+
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),
+# crop_size=(544, 1216),
+# crop_size = (544, 992),
+    crop_size = (616, 1064), # both dimensions are multiples of 28
+)
+
+log_interval = 100
+acc_batch = 1
+# online evaluation
+# evaluation = dict(online_eval=True, interval=1000, metrics=['abs_rel', 'delta1', 'rmse'], multi_dataset_eval=True)
+interval = 40000
+evaluation = dict(
+ online_eval=False,
+ interval=interval,
+ metrics=['abs_rel', 'delta1', 'rmse', 'normal_mean', 'normal_rmse', 'normal_a1'],
+ multi_dataset_eval=True,
+ exclude=['DIML_indoor', 'GL3D', 'Tourism', 'MegaDepth'],
+)
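+
+# The depth metrics above follow the standard definitions (listed here for
+# reference; the evaluator implementation lives elsewhere in training/mono):
+#   abs_rel = mean(|d_pred - d_gt| / d_gt)
+#   rmse    = sqrt(mean((d_pred - d_gt) ** 2))
+#   delta1  = fraction of pixels with max(d_pred / d_gt, d_gt / d_pred) < 1.25
+# and the normal_* entries are angular-error statistics between predicted and
+# ground-truth surface normals.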
+
+# save checkpoints during training; a runner type with the '*_AMP' suffix uses automatic mixed-precision training
+checkpoint_config = dict(by_epoch=False, interval=interval)
+runner = dict(type='IterBasedRunner_AMP', max_iters=800010)
+
+# optimizer
+optimizer = dict(
+ type='AdamW',
+# encoder=dict(lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-6),
+ encoder=dict(lr=8e-6, betas=(0.9, 0.999), weight_decay=1e-3, eps=1e-6),
+ decoder=dict(lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-6),
+ #strict_match=True
+)
+# schedule
+lr_config = dict(policy='poly',
+ warmup='linear',
+ warmup_iters=1000,
+ warmup_ratio=1e-6,
+ power=0.9, min_lr=1e-6, by_epoch=False)
+
+batchsize_per_gpu = 3
+thread_per_gpu = 1
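+# With acc_batch = 1 there is no gradient accumulation, so the effective batch
+# size is batchsize_per_gpu * num_gpus (the GPU count comes from the launcher,
+# not from this config).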
+
+Argovers2_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Cityscapes_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DIML_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Lyft_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DDAD_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ # sample_size = 1200,
+ ),
+ ))
+DSEC_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DrivingStereo_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+MapillaryPSD_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Pandaset_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Taskonomy_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+UASOL_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Waymo_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Matterport3D_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Replica_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+VKITTI_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+HM3D_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.75, 1.3),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+BlendedMVG_omni_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.75, 1.3),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ ),
+ ))
+ScanNetAll_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Hypersim_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
\ No newline at end of file
diff --git a/training/mono/configs/RAFTDecoder/vit.raft5.large.py b/training/mono/configs/RAFTDecoder/vit.raft5.large.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7ae460bfa062b3ebf940092760e282529d9b748
--- /dev/null
+++ b/training/mono/configs/RAFTDecoder/vit.raft5.large.py
@@ -0,0 +1,1047 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/ddad.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/datasets/argovers2.py',
+ '../_base_/datasets/cityscapes.py',
+ '../_base_/datasets/drivingstereo.py',
+ '../_base_/datasets/dsec.py',
+ '../_base_/datasets/lyft.py',
+ '../_base_/datasets/mapillary_psd.py',
+ '../_base_/datasets/diml.py',
+ '../_base_/datasets/taskonomy.py',
+ '../_base_/datasets/uasol.py',
+ '../_base_/datasets/pandaset.py',
+ '../_base_/datasets/waymo.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py',
+
+ '../_base_/datasets/hm3d.py',
+ '../_base_/datasets/matterport3d.py',
+ '../_base_/datasets/replica.py',
+ '../_base_/datasets/vkitti.py',
+ ]
+
+import numpy as np
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ ),
+)
+
+# loss method
+losses=dict(
+ decoder_losses=[
+ dict(type='VNLoss', sample_ratio=0.2, loss_weight=1.0),
+ dict(type='GRUSequenceLoss', loss_weight=0.5, loss_gamma=0.9, stereo_sup=0.0),
+ dict(type='SkyRegularizationLoss', loss_weight=0.001, sample_ratio=0.4, regress_value=200, normal_regress=[0, 0, -1]),
+ dict(type='HDNRandomLoss', loss_weight=0.5, random_num=10),
+ dict(type='HDSNRandomLoss', loss_weight=0.5, random_num=20, batch_limit=4),
+ dict(type='PWNPlanesLoss', loss_weight=1),
+ dict(type='NormalBranchLoss', loss_weight=1.0, loss_fn='NLL_ours_GRU'),
+ dict(type='DeNoConsistencyLoss', loss_weight=0.01, loss_fn='CEL', scale=2, depth_detach=True)
+ ],
+ gru_losses=[
+ dict(type='SkyRegularizationLoss', loss_weight=0.001, sample_ratio=0.4, regress_value=200, normal_regress=[0, 0, -1]),
+ ],
+)
+
+data_array = [
+ # Outdoor 1
+ [
+        dict(UASOL='UASOL_dataset'), # ~136k images
+        dict(Cityscapes_trainextra='Cityscapes_dataset'), # ~18k images
+        dict(Cityscapes_sequence='Cityscapes_dataset'), # ~135k images
+        dict(DIML='DIML_dataset'), # ~122k images
+        dict(Waymo='Waymo_dataset'), # ~990k images
+ ],
+ # Outdoor 2
+ [
+ dict(DSEC='DSEC_dataset'),
+        dict(Mapillary_PSD='MapillaryPSD_dataset'), # ~742k images
+        dict(DrivingStereo='DrivingStereo_dataset'), # ~176k images
+        dict(Argovers2='Argovers2_dataset'), # ~2.86M images
+ ],
+ # Outdoor 3
+ [
+        dict(Lyft='Lyft_dataset'), # ~158k images
+        dict(DDAD='DDAD_dataset'), # ~74k images
+        dict(Pandaset='Pandaset_dataset'), # ~38k images
+        dict(Virtual_KITTI='VKITTI_dataset'), # ~37k images, synthetic
+ ],
+    # Indoor 1
+    [
+        dict(Replica='Replica_dataset'), # ~56k images, synthetic
+        dict(Replica_gso='Replica_dataset'), # ~107k images, synthetic
+        dict(Hypersim='Hypersim_dataset'), # ~24k images
+ dict(ScanNetAll='ScanNetAll_dataset'),
+ ],
+ # Indoor 2
+ [
+        dict(Taskonomy='Taskonomy_dataset'), # ~4.47M images
+        dict(Matterport3D='Matterport3D_dataset'), # ~144k images
+        dict(HM3D='HM3D_dataset'), # ~2M images, very noisy; only a subset is sampled
+ ],
+]
+
+
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),
+# crop_size=(544, 1216),
+# crop_size = (544, 992),
+    crop_size = (616, 1064), # both dimensions are multiples of 28
+)
+
+log_interval = 100
+# online evaluation
+# evaluation = dict(online_eval=True, interval=1000, metrics=['abs_rel', 'delta1', 'rmse'], multi_dataset_eval=True)
+interval = 20000
+evaluation = dict(
+ #online_eval=True,
+ online_eval=False,
+ interval=interval,
+ metrics=['abs_rel', 'delta1', 'rmse', 'normal_mean', 'normal_rmse', 'normal_a1'],
+ multi_dataset_eval=True,
+ exclude=['DIML_indoor', 'GL3D', 'Tourism', 'MegaDepth'],
+)
+
+# save checkpoints during training; a runner type with the '*_AMP' suffix uses automatic mixed-precision training
+checkpoint_config = dict(by_epoch=False, interval=interval)
+runner = dict(type='IterBasedRunner_AMP', max_iters=800010)
+
+# optimizer
+optimizer = dict(
+ type='AdamW',
+# encoder=dict(lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-6),
+ encoder=dict(lr=1e-5, betas=(0.9, 0.999), weight_decay=1e-3, eps=1e-6),
+ decoder=dict(lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-6),
+)
+# schedule
+lr_config = dict(policy='poly',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=1e-6,
+ power=0.9, min_lr=1e-6, by_epoch=False)
+
+batchsize_per_gpu = 4
+thread_per_gpu = 4
+
+Argovers2_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Cityscapes_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DIML_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Lyft_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DDAD_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ # sample_size = 1200,
+ ),
+ ))
+DSEC_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DrivingStereo_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+MapillaryPSD_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Pandaset_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Taskonomy_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+UASOL_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Waymo_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Matterport3D_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Replica_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+VKITTI_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+HM3D_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.75, 1.3),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+BlendedMVG_omni_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.75, 1.3),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ ),
+ ))
+ScanNetAll_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Hypersim_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
\ No newline at end of file
diff --git a/training/mono/configs/RAFTDecoder/vit.raft5.small.py b/training/mono/configs/RAFTDecoder/vit.raft5.small.py
new file mode 100644
index 0000000000000000000000000000000000000000..484e1df74f598faf4bd08c9698ab512f92ebb3f5
--- /dev/null
+++ b/training/mono/configs/RAFTDecoder/vit.raft5.small.py
@@ -0,0 +1,1047 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/ddad.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/datasets/argovers2.py',
+ '../_base_/datasets/cityscapes.py',
+ '../_base_/datasets/drivingstereo.py',
+ '../_base_/datasets/dsec.py',
+ '../_base_/datasets/lyft.py',
+ '../_base_/datasets/mapillary_psd.py',
+ '../_base_/datasets/diml.py',
+ '../_base_/datasets/taskonomy.py',
+ '../_base_/datasets/uasol.py',
+ '../_base_/datasets/pandaset.py',
+ '../_base_/datasets/waymo.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py',
+
+ '../_base_/datasets/hm3d.py',
+ '../_base_/datasets/matterport3d.py',
+ '../_base_/datasets/replica.py',
+ '../_base_/datasets/vkitti.py',
+ ]
+
+import numpy as np
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ ),
+)
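+# RAFTDepthNormalDPT5 presumably refines the joint depth/normal prediction with a
+# RAFT-style GRU: iters=4 refinement steps, and n_downsample=2 suggests the updates
+# run at a reduced resolution before final upsampling (assumption, not verified here).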
+
+# loss method
+losses=dict(
+ decoder_losses=[
+ dict(type='VNLoss', sample_ratio=0.2, loss_weight=1.0),
+ dict(type='GRUSequenceLoss', loss_weight=0.5, loss_gamma=0.9, stereo_sup=0.0),
+ dict(type='SkyRegularizationLoss', loss_weight=0.001, sample_ratio=0.4, regress_value=200, normal_regress=[0, 0, -1]),
+ dict(type='HDNRandomLoss', loss_weight=0.5, random_num=10),
+ dict(type='HDSNRandomLoss', loss_weight=0.5, random_num=20, batch_limit=4),
+ dict(type='PWNPlanesLoss', loss_weight=1),
+ dict(type='NormalBranchLoss', loss_weight=1.0, loss_fn='NLL_ours_GRU'),
+ dict(type='DeNoConsistencyLoss', loss_weight=0.01, loss_fn='CEL', scale=2, depth_detach=True)
+ ],
+ gru_losses=[
+ dict(type='SkyRegularizationLoss', loss_weight=0.001, sample_ratio=0.4, regress_value=200, normal_regress=[0, 0, -1]),
+ ],
+)
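+# The decoder objective is presumably the weighted sum of the terms above
+# (total = sum_i loss_weight_i * L_i); gru_losses are applied to the intermediate GRU
+# outputs. Note that SkyRegularizationLoss pushes sky pixels towards regress_value=200,
+# i.e. the upper bound of depth_normalize in data_basic.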
+
+data_array = [
+ # Outdoor 1
+ [
+ dict(UASOL='UASOL_dataset'), #13.6w
+ dict(Cityscapes_trainextra='Cityscapes_dataset'), #1.8w
+ dict(Cityscapes_sequence='Cityscapes_dataset'), #13.5w
+ dict(DIML='DIML_dataset'), # 12.2w
+ dict(Waymo='Waymo_dataset'), # 99w
+ ],
+ # Outdoor 2
+ [
+ dict(DSEC='DSEC_dataset'),
+ dict(Mapillary_PSD='MapillaryPSD_dataset'), # 74.2w
+ dict(DrivingStereo='DrivingStereo_dataset'), # 17.6w
+ dict(Argovers2='Argovers2_dataset'), # 285.6w
+ ],
+ # Outdoor 3
+ [
+ dict(Lyft='Lyft_dataset'), #15.8w
+ dict(DDAD='DDAD_dataset'), #7.4w
+ dict(Pandaset='Pandaset_dataset'), #3.8w
+ dict(Virtual_KITTI='VKITTI_dataset'), # 3.7w # syn
+ ],
+    # Indoor 1
+ [
+ dict(Replica='Replica_dataset'), # 5.6w # syn
+ dict(Replica_gso='Replica_dataset'), # 10.7w # syn
+ dict(Hypersim='Hypersim_dataset'), # 2.4w
+ dict(ScanNetAll='ScanNetAll_dataset'),
+ ],
+ # Indoor 2
+ [
+ dict(Taskonomy='Taskonomy_dataset'), #447.2w
+ dict(Matterport3D='Matterport3D_dataset'), #14.4w
+        dict(HM3D='HM3D_dataset'), # 200w, very noisy; only a sampled subset is used
+ ],
+]
+
+
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),
+# crop_size=(544, 1216),
+# crop_size = (544, 992),
+    crop_size = (616, 1064), # both dimensions are divisible by 28
+)
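+# crop_size must be a multiple of 28 in both dimensions (616 = 28 * 22, 1064 = 28 * 38);
+# 28 presumably reflects the combined ViT patch size and decoder downsampling stride.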
+
+log_interval = 100
+# online evaluation
+# evaluation = dict(online_eval=True, interval=1000, metrics=['abs_rel', 'delta1', 'rmse'], multi_dataset_eval=True)
+interval = 20000
+evaluation = dict(
+ #online_eval=True,
+ online_eval=False,
+ interval=interval,
+ metrics=['abs_rel', 'delta1', 'rmse', 'normal_mean', 'normal_rmse', 'normal_a1'],
+ multi_dataset_eval=True,
+ exclude=['DIML_indoor', 'GL3D', 'Tourism', 'MegaDepth'],
+)
+
+# save checkpoints during training; a runner type ending in '_AMP' enables automatic mixed-precision training
+checkpoint_config = dict(by_epoch=False, interval=interval)
+runner = dict(type='IterBasedRunner_AMP', max_iters=800010)
+
+# optimizer
+optimizer = dict(
+ type='AdamW',
+# encoder=dict(lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-6),
+ encoder=dict(lr=1e-5, betas=(0.9, 0.999), weight_decay=1e-3, eps=1e-6),
+ decoder=dict(lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-6),
+)
+# schedule
+lr_config = dict(policy='poly',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=1e-6,
+ power=0.9, min_lr=1e-6, by_epoch=False)
+
+batchsize_per_gpu = 6
+thread_per_gpu = 4
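+# Effective global batch size is batchsize_per_gpu times the number of GPUs passed at
+# launch; thread_per_gpu presumably sets the number of dataloader workers per GPU.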
+
+Argovers2_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Cityscapes_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DIML_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Lyft_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DDAD_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ # sample_size = 1200,
+ ),
+ ))
+DSEC_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DrivingStereo_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+MapillaryPSD_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Pandaset_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Taskonomy_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+UASOL_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Waymo_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Matterport3D_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Replica_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+VKITTI_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+HM3D_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.75, 1.3),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+BlendedMVG_omni_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.75, 1.3),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ ),
+ ))
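+# BlendedMVG_omni_dataset is defined above but not referenced in data_array;
+# it is presumably kept so it can be swapped into a dataset group when needed.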
+ScanNetAll_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Hypersim_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_size=(0,0), # crop_size will be overwriteen by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
\ No newline at end of file
diff --git a/training/mono/configs/RAFTDecoder/vit.raft5.small.sanity_check.py b/training/mono/configs/RAFTDecoder/vit.raft5.small.sanity_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..a882418caeeb35a0778c526ed81a771306a775db
--- /dev/null
+++ b/training/mono/configs/RAFTDecoder/vit.raft5.small.sanity_check.py
@@ -0,0 +1,1014 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/ddad.py',
+ '../_base_/datasets/_data_base_.py',
+ '../_base_/datasets/argovers2.py',
+ '../_base_/datasets/cityscapes.py',
+ '../_base_/datasets/drivingstereo.py',
+ '../_base_/datasets/dsec.py',
+ '../_base_/datasets/lyft.py',
+ '../_base_/datasets/mapillary_psd.py',
+ '../_base_/datasets/diml.py',
+ '../_base_/datasets/taskonomy.py',
+ '../_base_/datasets/uasol.py',
+ '../_base_/datasets/pandaset.py',
+ '../_base_/datasets/waymo.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py',
+
+ '../_base_/datasets/hm3d.py',
+ '../_base_/datasets/matterport3d.py',
+ '../_base_/datasets/replica.py',
+ '../_base_/datasets/vkitti.py',
+ ]
+
+import numpy as np
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ ),
+)
+
+# loss method
+losses=dict(
+ decoder_losses=[
+ dict(type='VNLoss', sample_ratio=0.2, loss_weight=1.0),
+ dict(type='GRUSequenceLoss', loss_weight=0.5, loss_gamma=0.9, stereo_sup=0.0),
+ dict(type='SkyRegularizationLoss', loss_weight=0.001, sample_ratio=0.4, regress_value=200, normal_regress=[0, 0, -1]),
+ dict(type='HDNRandomLoss', loss_weight=0.5, random_num=10),
+ dict(type='HDSNRandomLoss', loss_weight=0.5, random_num=20, batch_limit=4),
+ dict(type='PWNPlanesLoss', loss_weight=1),
+ dict(type='NormalBranchLoss', loss_weight=1.0, loss_fn='NLL_ours_GRU'),
+ dict(type='DeNoConsistencyLoss', loss_weight=0.01, loss_fn='CEL', scale=2, depth_detach=True)
+ ],
+ gru_losses=[
+ dict(type='SkyRegularizationLoss', loss_weight=0.001, sample_ratio=0.4, regress_value=200, normal_regress=[0, 0, -1]),
+ ],
+)
+
+data_array = [
+ [
+ dict(Matterport3D='Matterport3D_dataset'), #14.4w
+ ],
+]
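+# Sanity-check setup: only Matterport3D is used, so a single dataset group exercises
+# the full training loop quickly before launching the full multi-dataset configs.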
+
+
+
+# configs of the canonical space
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),
+# crop_size=(544, 1216),
+# crop_size = (544, 992),
+    crop_size = (616, 1064), # both dimensions are divisible by 28
+)
+
+log_interval = 100
+# online evaluation
+# evaluation = dict(online_eval=True, interval=1000, metrics=['abs_rel', 'delta1', 'rmse'], multi_dataset_eval=True)
+interval = 20000
+evaluation = dict(
+ #online_eval=True,
+ online_eval=False,
+ interval=interval,
+ metrics=['abs_rel', 'delta1', 'rmse', 'normal_mean', 'normal_rmse', 'normal_a1'],
+ multi_dataset_eval=True,
+)
+
+# save checkpoints during training; a runner type ending in '_AMP' enables automatic mixed-precision training
+checkpoint_config = dict(by_epoch=False, interval=interval)
+runner = dict(type='IterBasedRunner_AMP', max_iters=800010)
+
+# optimizer
+optimizer = dict(
+ type='AdamW',
+# encoder=dict(lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-6),
+ encoder=dict(lr=1e-5, betas=(0.9, 0.999), weight_decay=1e-3, eps=1e-6),
+ decoder=dict(lr=1e-4, betas=(0.9, 0.999), weight_decay=0.01, eps=1e-6),
+)
+# schedule
+lr_config = dict(policy='poly',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=1e-6,
+ power=0.9, min_lr=1e-6, by_epoch=False)
+
+batchsize_per_gpu = 3
+thread_per_gpu = 4
+
+Argovers2_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Cityscapes_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DIML_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Lyft_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DDAD_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ # sample_size = 1200,
+ ),
+ ))
+DSEC_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+DrivingStereo_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+MapillaryPSD_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Pandaset_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Taskonomy_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+UASOL_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Waymo_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=True),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Matterport3D_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Replica_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+VKITTI_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+HM3D_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.75, 1.3),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+BlendedMVG_omni_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.75, 1.3),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ ),
+ ))
+ScanNetAll_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
+Hypersim_dataset=dict(
+ data = dict(
+ train=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomResize',
+ prob=0.5,
+ ratio_range=(0.85, 1.15),
+ is_lidar=False),
+ dict(type='RandomCrop',
+                      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.05),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ #sample_size = 10000,
+ ),
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='RandomCrop',
+                   crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_size = 1200,
+ ),
+ ))
\ No newline at end of file
diff --git a/training/mono/configs/__init__.py b/training/mono/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/training/mono/configs/__init__.py
@@ -0,0 +1 @@
+
diff --git a/training/mono/configs/_base_/datasets/7scenes.py b/training/mono/configs/_base_/datasets/7scenes.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d2e42a9bdd2c9e8c2ffb8a6f637c617e978b875
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/7scenes.py
@@ -0,0 +1,83 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ./_data_base_.py
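+# Note: metric_scale below is assumed to be the divisor that converts the stored depth
+# values to metres; 1000.0 is consistent with depth maps saved in millimetres.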
+
+SevenScenes_dataset=dict(
+ lib = 'SevenScenesDataset',
+ data_root = 'data/public_datasets',
+ data_name = '7Scenes',
+ transfer_to_canonical = True,
+ metric_scale = 1000.0,
+ original_focal_length = 500,
+ original_size = (480, 640),
+ data_type='denselidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='ETH3D/annotations/test_annotations_new.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='ETH3D/annotations/test_annotations_new.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='AdjustSize',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='ETH3D/annotations/test_annotations_new.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                #          crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/_data_base_.py b/training/mono/configs/_base_/datasets/_data_base_.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f1d339ad89ad1c9a0fec6c5bee928a2462b2eb1
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/_data_base_.py
@@ -0,0 +1,12 @@
+# canonical camera setting and basic data setting
+
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(540, 960),
+ focal_length=1196.0,
+ ),
+ depth_range=(0.9, 150),
+ depth_normalize=(0.006, 1.001),
+ crop_size = (512, 960),
+ clip_depth_range=(0.1, 200),
+)
diff --git a/training/mono/configs/_base_/datasets/argovers2.py b/training/mono/configs/_base_/datasets/argovers2.py
new file mode 100644
index 0000000000000000000000000000000000000000..158841701fa3cf2ddbb8092f9d6992dc760d4735
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/argovers2.py
@@ -0,0 +1,74 @@
+# dataset settings
+
+Argovers2_dataset=dict(
+ lib = 'Argovers2Dataset',
+ data_root = 'data/public_datasets',
+ data_name = 'Argovers2',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (1688.844624443858, 1776.8498213965734),
+ original_size = (1550, 2048),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='Argovers2/annotations/train_annotations_wneigh.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='Argovers2/annotations/val_annotations_wneigh.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='Argovers2/annotations/test_annotations_wneigh.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 6000,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/blended_mvg.py b/training/mono/configs/_base_/datasets/blended_mvg.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ee6b8dce6c132dc9293dc7319517e56fe315f43
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/blended_mvg.py
@@ -0,0 +1,78 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ./_data_base_.py
+
+BlendedMVG_omni_dataset=dict(
+ lib = 'BlendedMVGOmniDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'BlendedMVG_omni',
+ transfer_to_canonical = True,
+ metric_scale = 512.0,
+ original_focal_length = 575.6656,
+ original_size = (576, 768),
+ data_type='denselidar_nometric',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='BlendedMVG/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.05,),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 50)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='BlendedMVG/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 5,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='BlendedMVG/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[123.675, 116.28, 103.53]),
+ # dict(type='RandomCrop',
+                #          crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/cityscapes.py b/training/mono/configs/_base_/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff3721ce6751bf159cc929351902730adccedec0
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/cityscapes.py
@@ -0,0 +1,79 @@
+# dataset settings
+
+Cityscapes_dataset=dict(
+ lib = 'CityscapesDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'Cityscapes',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (2263.9108952994275, 2263.9108952994275),
+ original_size = (1024, 2048),
+ data_type='stereo',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='Cityscapes_sequence/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='Cityscapes_sequence/annotations/val.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='Cityscapes_sequence/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                #          crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/ddad.py b/training/mono/configs/_base_/datasets/ddad.py
new file mode 100644
index 0000000000000000000000000000000000000000..522dc563fb639d3254eb116f247aecdce0e79c7d
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/ddad.py
@@ -0,0 +1,80 @@
+# dataset settings
+
+DDAD_dataset=dict(
+ lib = 'DDADDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'DDAD',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (2181, 1060),
+ original_size = (1216, 1936),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='DDAD/annotations/train_annotations.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='DDAD/annotations/val_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='DDAD/annotations/test_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ # dict(type='LabelScaleCononical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960), #(1216, 1952), #
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                #          crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 800,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/ddad_any.py b/training/mono/configs/_base_/datasets/ddad_any.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dc24d84df26cd4b778f21ab65775abd453853d1
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/ddad_any.py
@@ -0,0 +1,79 @@
+# dataset settings
+
+DDADAny_dataset=dict(
+ lib = 'AnyDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'DDAD',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (2181, 1060),
+ original_size = (1216, 1936),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='DDAD/annotations/train_annotations.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='DDAD/annotations/val_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='DDAD/annotations/test_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                #          crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 6000,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/diml.py b/training/mono/configs/_base_/datasets/diml.py
new file mode 100644
index 0000000000000000000000000000000000000000..71fe2a7741f9a0871b184eb722f3906bd7860202
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/diml.py
@@ -0,0 +1,79 @@
+# dataset settings
+
+DIML_dataset=dict(
+ lib = 'DIMLDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'DIML',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (1398.402, ),
+ original_size = (1080, 1920),
+ data_type='stereo',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='DIML/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='DIML/annotations/val.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='DIML/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                #          crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/diml_indoor.py b/training/mono/configs/_base_/datasets/diml_indoor.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c2721effc6317c402c289a803dfa591b440970e
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/diml_indoor.py
@@ -0,0 +1,76 @@
+# dataset settings
+
+DIML_indoor_dataset=dict(
+ lib = 'DIMLDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'DIML_indoor',
+ metric_scale = 1000.0,
+ data_type='stereo_nocamera',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='DIML/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='DIML/annotations/val.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='DIML/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                #          crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/diode.py b/training/mono/configs/_base_/datasets/diode.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6a8de74e6f3101e7d9a39721dbe6eb132c68eed
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/diode.py
@@ -0,0 +1,80 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ./_data_base_.py
+
+DIODE_dataset=dict(
+ lib = 'DIODEDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'DIODE',
+ transfer_to_canonical = True,
+ metric_scale = 1.0,
+ original_focal_length = 886.81,
+ original_size = (764, 1024),
+ data_type='denselidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='DIODE/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='DIODE/annotations/val.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 50,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='DIODE/annotations/test_annotations_new.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                #          crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/drivingstereo.py b/training/mono/configs/_base_/datasets/drivingstereo.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f770a7adb692a28dd621eb174361cd46e13d20a
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/drivingstereo.py
@@ -0,0 +1,79 @@
+# dataset settings
+
+DrivingStereo_dataset=dict(
+ lib = 'DrivingStereoDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'DrivingStereo',
+ transfer_to_canonical = True,
+ metric_scale = 256.0,
+ original_focal_length = (1006.938, 1003.556),
+ original_size = (400, 881),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='DrivingStereo/annotations/train_annotations.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='DrivingStereo/annotations/val_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='DrivingStereo/annotations/test_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                #          crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/dsec.py b/training/mono/configs/_base_/datasets/dsec.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d1bbcd05f6194f583d39d7b26193860f966faf8
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/dsec.py
@@ -0,0 +1,79 @@
+# dataset settings
+
+DSEC_dataset=dict(
+ lib = 'DSECDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'DSEC',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (1150.8943600390282, ),
+ original_size = (1080, 1440),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='DSEC/annotations/train_annotations_wtmpl.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='DSEC/annotations/val_annotations_wtmpl.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='DSEC/annotations/test_annotations_wtmpl.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                #          crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/eth3d.py b/training/mono/configs/_base_/datasets/eth3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..660db92b301cf48f800b1551ed268b7169dec64a
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/eth3d.py
@@ -0,0 +1,80 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ./_data_base_.py
+
+ETH3D_dataset=dict(
+ lib = 'ETH3DDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'ETH3D',
+ transfer_to_canonical = True,
+ metric_scale = 1.0,
+ original_focal_length = 886.81,
+ original_size = (764, 1024),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='ETH3D/annotations/test_annotations_new.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+    # configs for the validation pipeline
+ val=dict(
+ anno_path='ETH3D/annotations/test_annotations_new.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                         crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+    # configs for the test pipeline
+ test=dict(
+ anno_path='ETH3D/annotations/test_annotations_new.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                #          crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/hm3d.py b/training/mono/configs/_base_/datasets/hm3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c800a616668066b1a8feeaeffdadd6a0e4cd2298
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/hm3d.py
@@ -0,0 +1,78 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ./_data_base_.py
+
+HM3D_dataset=dict(
+ lib = 'HM3DDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'HM3D',
+ transfer_to_canonical = True,
+ metric_scale = 512.0,
+ original_focal_length = 575.6656,
+ original_size = (512, 512),
+ data_type='denselidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='HM3D/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.2)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.0,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.05,),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 50)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='HM3D/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='HM3D/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/hypersim.py b/training/mono/configs/_base_/datasets/hypersim.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6cf4e2ad272d110f2b4b275a31b0683cefc715e
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/hypersim.py
@@ -0,0 +1,71 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ._data_base_.py
+
+Hypersim_dataset=dict(
+ lib = 'HypersimDataset',
+ data_name = 'Hypersim',
+ metric_scale = 1.0,
+ data_type='denselidar_syn',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.3)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.0,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.05,),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 50)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 200,),
+        # configs for the test pipeline
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 2000,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/ibims.py b/training/mono/configs/_base_/datasets/ibims.py
new file mode 100644
index 0000000000000000000000000000000000000000..0851029095748b90bf9d1b6c4b7cd03b17f2f345
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/ibims.py
@@ -0,0 +1,80 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ._data_base_.py
+
+IBIMS_dataset=dict(
+ lib = 'IBIMSDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'IBIMS',
+ transfer_to_canonical = True,
+ metric_scale = 1000.0,
+ original_focal_length = 518.857,
+ original_size = (480, 640),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='IBIMS/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='IBIMS/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='IBIMS/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/kitti.py b/training/mono/configs/_base_/datasets/kitti.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d68f806bea0333c6b6eecfb99c9384adfef2023
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/kitti.py
@@ -0,0 +1,80 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ._data_base_.py
+
+KITTI_dataset=dict(
+ lib = 'KITTIDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'KITTI',
+ transfer_to_canonical = True,
+ metric_scale = 256.0,
+ original_focal_length = 518.857,
+ original_size = (480, 640),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='KITTI/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='KITTI/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='KITTI/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/leddarpixset.py b/training/mono/configs/_base_/datasets/leddarpixset.py
new file mode 100644
index 0000000000000000000000000000000000000000..27eb3e6d04397792c9a5ed3e3afc9b6c5b827b00
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/leddarpixset.py
@@ -0,0 +1,80 @@
+# dataset settings
+
+LeddarPixSet_dataset=dict(
+ lib = 'LeddarPixSetDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'LeddarPixSet',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (2181, 1060),
+ original_size = (1080, 1440),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='LeddarPixSet/annotations/train_annotations.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='LeddarPixSet/annotations/val_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 50,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='LeddarPixSet/annotations/test_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ # dict(type='LabelScaleCononical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960), #(1216, 1952), #
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
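+
+# Usage sketch (an assumption about the surrounding trainer code, not verified here):
+# dataset dicts such as LeddarPixSet_dataset are meant to be pulled into an experiment
+# config via the mmcv-style _base_ mechanism used elsewhere in this repo, e.g.
+#   _base_ = ['../_base_/datasets/leddarpixset.py']
+# after which the train/val/test pipelines above are resolved against the data_basic
+# settings (canonical space, crop_size) at build time.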
diff --git a/training/mono/configs/_base_/datasets/lyft.py b/training/mono/configs/_base_/datasets/lyft.py
new file mode 100644
index 0000000000000000000000000000000000000000..5917ec9fb5e820834257615267360337c7530b4b
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/lyft.py
@@ -0,0 +1,79 @@
+# dataset settings
+
+Lyft_dataset=dict(
+ lib = 'LyftDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'Lyft',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (877.406430795, 3416.79, 1108.782, 3986.358, 3427.04, ),
+ original_size = (1024, 1224),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='Lyft/annotations/train_annotations_wtmpl.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='Lyft/annotations/val_annotations_wtmpl.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='Lyft/annotations/test_annotations_wtmpl.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 6000,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/lyft_any.py b/training/mono/configs/_base_/datasets/lyft_any.py
new file mode 100644
index 0000000000000000000000000000000000000000..5775563e8462922168257b240b0d2c2ce9d22214
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/lyft_any.py
@@ -0,0 +1,79 @@
+# dataset settings
+
+LyftAny_dataset=dict(
+ lib = 'AnyDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'Lyft',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (877.406430795, 880.82631362),
+ original_size = (1024, 1224),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='Lyft/annotations/train_annotations_wtmpl.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='Lyft/annotations/val_annotations_wtmpl.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='Lyft/annotations/test_annotations_wtmpl.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[123.675, 116.28, 103.53]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 6000,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/mapillary_psd.py b/training/mono/configs/_base_/datasets/mapillary_psd.py
new file mode 100644
index 0000000000000000000000000000000000000000..744e246d4e7832fd60eb9695d33dd873205cae5d
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/mapillary_psd.py
@@ -0,0 +1,79 @@
+# dataset settings
+
+MapillaryPSD_dataset=dict(
+ lib = 'MapillaryPSDDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'MapillaryPSD',
+ transfer_to_canonical = True,
+ metric_scale = 256.0,
+ original_focal_length = (1664.38, 1725.494, 1231.4812, 2576.447),
+ original_size = (1536, 2048),
+ data_type='sfm',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='Mapillary_PSD/annotations/train_annotations.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand', # center, rand, rand_in_field
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='Mapillary_PSD/annotations/val_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='Mapillary_PSD/annotations/test_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/matterport3d.py b/training/mono/configs/_base_/datasets/matterport3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1d3b5a8da21720850b77705c9488a5adef5d741
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/matterport3d.py
@@ -0,0 +1,78 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ._data_base_.py
+
+Matterport3D_dataset=dict(
+ lib = 'Matterport3DDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'Matterport3D',
+ transfer_to_canonical = True,
+ metric_scale = 4000.0,
+ original_focal_length = 575.6656,
+ original_size = (1024, 1280),
+ data_type='denselidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='Matterport3D/annotations/test.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.05,),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 50)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='Matterport3D/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='Matterport3D/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/nuscenes.py b/training/mono/configs/_base_/datasets/nuscenes.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d47b3937d501929c1efdba25030ef4e6744feb4
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/nuscenes.py
@@ -0,0 +1,79 @@
+# dataset settings
+
+NuScenes_dataset=dict(
+ lib = 'NuScenesDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'NuScenes',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (877.406430795, 1200.82631362),
+ original_size = (1024, 1224),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='NuScenes/annotations/train_annotations_wtmpl.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='NuScenes/annotations/val_annotations_wtmpl.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='NuScenes/annotations/test_annotations_wtmpl.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/nuscenes_any.py b/training/mono/configs/_base_/datasets/nuscenes_any.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b1af09a1eecd9a3db11bc9596a439cecc4e58fb
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/nuscenes_any.py
@@ -0,0 +1,79 @@
+# dataset settings
+
+NuScenesAny_dataset=dict(
+ lib = 'AnyDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'NuScenes',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (877.406430795, 1200.82631362),
+ original_size = (1024, 1224),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='NuScenes/annotations/train_annotations_wtmpl.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='NuScenes/annotations/val_annotations_wtmpl.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='NuScenes/annotations/test_annotations_wtmpl.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/nyu.py b/training/mono/configs/_base_/datasets/nyu.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5e81e07893e30daf05ba5ce644e3c9ab6000330
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/nyu.py
@@ -0,0 +1,80 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ._data_base_.py
+
+NYU_dataset=dict(
+ lib = 'NYUDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'NYU',
+ transfer_to_canonical = True,
+ metric_scale = 6000.0,
+ original_focal_length = 518.857,
+ original_size = (480, 640),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='NYU/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='NYU/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='NYU/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/pandaset.py b/training/mono/configs/_base_/datasets/pandaset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e59ed9fc9a9676f42abe2e6665ce6a801e4f9d0
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/pandaset.py
@@ -0,0 +1,79 @@
+# dataset settings
+
+Pandaset_dataset=dict(
+ lib = 'PandasetDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'Pandaset',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (1970.01, 930.45, 929.84),
+ original_size = (1080, 1920),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='Pandaset/annotations/annotations_train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='Pandaset/annotations/annotations_val.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='Pandaset/annotations/annotations_test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 800,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/replica.py b/training/mono/configs/_base_/datasets/replica.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bd849813ea0894875aee1c51d36a9bd269ab3d6
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/replica.py
@@ -0,0 +1,78 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ._data_base_.py
+
+Replica_dataset=dict(
+ lib = 'ReplicaDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'Replica',
+ transfer_to_canonical = True,
+ metric_scale = 512.0,
+ original_focal_length = 575.6656,
+ original_size = (512, 512),
+ data_type='denselidar_syn',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='Replica/annotations/test.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.05,),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 50)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='Replica/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 50,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='Replica/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 2000,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/scannet.py b/training/mono/configs/_base_/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ce2390bb1e4444cf6c24d75f4a04ef1407fd1b1
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/scannet.py
@@ -0,0 +1,80 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ._data_base_.py
+
+ScanNet_dataset=dict(
+ lib = 'ScanNetDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'ScanNet',
+ transfer_to_canonical = True,
+ metric_scale = 1000.0,
+ original_focal_length = 1165.371094,
+ original_size = (968, 1296),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='ScanNet/annotations/test.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='ScanNet/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='ScanNet/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/scannet_all.py b/training/mono/configs/_base_/datasets/scannet_all.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa1e025af160f18b617a1a6c8c02fd1c5f773655
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/scannet_all.py
@@ -0,0 +1,80 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ._data_base_.py
+
+ScanNetAll_dataset=dict(
+ lib = 'ScanNetDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'ScanNetAll',
+ transfer_to_canonical = True,
+ metric_scale = 1000.0,
+ original_focal_length = 1165.371094,
+ original_size = (968, 1296),
+ data_type='denselidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='ScanNet/annotations/test.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='ScanNet/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='ScanNet/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/taskonomy.py b/training/mono/configs/_base_/datasets/taskonomy.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7ad3f1053ae4556905403b76a8d810c4d787afc
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/taskonomy.py
@@ -0,0 +1,78 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ._data_base_.py
+
+Taskonomy_dataset=dict(
+ lib = 'TaskonomyDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'Taskonomy',
+ transfer_to_canonical = True,
+ metric_scale = 512.0,
+ original_focal_length = 575.6656,
+ original_size = (512, 512),
+ data_type='denselidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='Taskonomy/annotations/test.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.3)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.0,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.1,
+ distortion_prob=0.05,),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 50)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='Taskonomy/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 20,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='Taskonomy/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 2000,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/uasol.py b/training/mono/configs/_base_/datasets/uasol.py
new file mode 100644
index 0000000000000000000000000000000000000000..b80efd1c60ccf252d92ce946728ba8c5fc0a83a9
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/uasol.py
@@ -0,0 +1,74 @@
+# dataset settings
+
+UASOL_dataset=dict(
+ lib = 'UASOLDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'UASOL',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = (2263.9108952994275, 2263.9108952994275),
+ original_size = (1024, 2048),
+ data_type='stereo',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='UASOL/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='UASOL/annotations/test_all.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 100,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='UASOL/annotations/test_all.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
diff --git a/training/mono/configs/_base_/datasets/vkitti.py b/training/mono/configs/_base_/datasets/vkitti.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2f7b5b39d0ab7237f0b64fecc4190fa8ac497d5
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/vkitti.py
@@ -0,0 +1,80 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ._data_base_.py
+
+VKITTI_dataset=dict(
+ lib = 'VKITTIDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'VKITTI',
+ transfer_to_canonical = True,
+ metric_scale = 100.0,
+ original_focal_length = 725.0087,
+ original_size = (375, 1242),
+ data_type='denselidar_syn',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='VKITTI/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='VKITTI/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 50,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='VKITTI/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/datasets/waymo.py b/training/mono/configs/_base_/datasets/waymo.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ac9d95fc15a9be431a044d0fad7d391b6d6ab10
--- /dev/null
+++ b/training/mono/configs/_base_/datasets/waymo.py
@@ -0,0 +1,80 @@
+# dataset settings
+# data will be resized/cropped to the canonical size; refer to ._data_base_.py
+
+Waymo_dataset=dict(
+ lib = 'WaymoDataset',
+ data_root = 'data/public_datasets',
+ data_name = 'Waymo',
+ transfer_to_canonical = True,
+ metric_scale = 200.0,
+ original_focal_length = 2000.8,
+ original_size = (2000, 2000),
+ data_type='lidar',
+ data = dict(
+ # configs for the training pipeline
+ train=dict(
+ anno_path='Waymo/annotations/train.json',
+ sample_ratio = 1.0,
+ sample_size = -1,
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(0.9, 1.4)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='rand',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='RandomEdgeMask',
+ mask_maxsize=50,
+ prob=0.2,
+ rgb_invalid=[0,0,0],
+ label_invalid=-1,),
+ dict(type='RandomHorizontalFlip',
+ prob=0.4),
+ dict(type='PhotoMetricDistortion',
+ to_gray_prob=0.2,
+ distortion_prob=0.1,),
+ dict(type='Weather',
+ prob=0.1),
+ dict(type='RandomBlur',
+ prob=0.05),
+ dict(type='RGBCompresion', prob=0.1, compression=(0, 40)),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],),
+
+        # configs for the validation pipeline
+ val=dict(
+ anno_path='Waymo/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='ResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='RandomCrop',
+                     crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ crop_type='center',
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 50,),
+        # configs for the test pipeline
+ test=dict(
+ anno_path='Waymo/annotations/test.json',
+ pipeline=[dict(type='BGR2RGB'),
+ # dict(type='LiDarResizeCanonical', ratio_range=(1.0, 1.0)),
+ dict(type='ResizeKeepRatio',
+ resize_size=(512, 960),
+ ignore_label=-1,
+ padding=[0, 0, 0]),
+ # dict(type='RandomCrop',
+                  #      crop_size=(0,0), # crop_size will be overwritten by data_basic configs
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0, 0, 0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,),
+ ),
+)
\ No newline at end of file
diff --git a/training/mono/configs/_base_/default_runtime.py b/training/mono/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..4815a5c0c6bce22f2b8a499f033de971f146aeda
--- /dev/null
+++ b/training/mono/configs/_base_/default_runtime.py
@@ -0,0 +1,23 @@
+# distributed training configs; if dist_url == 'env://' (or e.g. 'tcp://127.0.0.1:6795'), node-related configs should be set in the shell
+dist_params = dict(port=None, backend='nccl', dist_url='env://')
+
+log_name = 'tbd'
+log_file = 'out.log'
+
+load_from = None
+resume_from = None
+
+#workflow = [('train', 1)]
+cudnn_benchmark = True
+log_interval = 20
+
+use_tensorboard = True
+
+evaluation = dict(online_eval=True, interval=1000, metrics=['abs_rel', 'delta1'])
+checkpoint_config = dict(by_epoch=False, interval=16000)
+
+
+# runtime settings: IterBasedRunner or EpochBasedRunner, e.g. runner = dict(type='EpochBasedRunner', max_epochs=100)
+runner = dict(type='IterBasedRunner', max_iters=160000)
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'rmse_log', 'log10', 'sq_rel']
\ No newline at end of file
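+
+# Launch note (assumption): with dist_url='env://', torch.distributed reads MASTER_ADDR,
+# MASTER_PORT, RANK and WORLD_SIZE from the environment, so export them in the launch
+# shell or start the job with torchrun (which sets them), e.g.
+#   torchrun --nproc_per_node=8 <training entry script> <experiment config>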
diff --git a/training/mono/configs/_base_/losses/all_losses.py b/training/mono/configs/_base_/losses/all_losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0ad857e7b2f859e72f9bf4556a97e3d6bed6326
--- /dev/null
+++ b/training/mono/configs/_base_/losses/all_losses.py
@@ -0,0 +1,26 @@
+"""
+There are multiple losses that can be applied.
+
+dict(type='GradientLoss_Li', scale_num=4, loss_weight=1.0),
+dict(type='VNLoss', sample_ratio=0.2, loss_weight=1.0),
+dict(type='SilogLoss', variance_focus=0.5, loss_weight=1.0),
+dict(type='WCELoss', loss_weight=1.0, depth_normalize=(0.1, 1), bins_num=200)
+dict(type='RegularizationLoss', loss_weight=0.1)
+dict(type='EdgeguidedRankingLoss', loss_weight=1.0)
+Note that out_channel and depth_normalize will be overwritten by configs in data_basic.
+"""
+
+# loss_decode=[dict(type='VNLoss', sample_ratio=0.2, loss_weight=1.0),
+# #dict(type='SilogLoss', variance_focus=0.5, loss_weight=1.0),
+# dict(type='WCELoss', loss_weight=1.0, depth_normalize=(0, 0), out_channel=0)]
+
+# loss_auxi = [#dict(type='WCELoss', loss_weight=1.0, depth_normalize=(0.1, 1), out_channel=200),
+# ]
+losses=dict(
+ decoder_losses=[
+ dict(type='VNLoss', sample_ratio=0.2, loss_weight=1.0),
+ dict(type='WCELoss', loss_weight=1.0, depth_normalize=(0, 0), out_channel=0),
+ ],
+ auxi_losses=[],
+ pose_losses=[],
+)
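+
+# Illustrative only (not used by any config here): a heavier supervision mix could be
+# assembled from the losses listed in the docstring above, e.g.
+# decoder_losses=[
+#     dict(type='VNLoss', sample_ratio=0.2, loss_weight=1.0),
+#     dict(type='SilogLoss', variance_focus=0.5, loss_weight=1.0),
+#     dict(type='GradientLoss_Li', scale_num=4, loss_weight=1.0),
+# ],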
diff --git a/training/mono/configs/_base_/models/backbones/dino_vit_giant2_reg.py b/training/mono/configs/_base_/models/backbones/dino_vit_giant2_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c1ebc96ceaa32ad9310d3b84d55d252be843c46
--- /dev/null
+++ b/training/mono/configs/_base_/models/backbones/dino_vit_giant2_reg.py
@@ -0,0 +1,7 @@
+model = dict(
+ backbone=dict(
+ type='vit_giant2_reg',
+ prefix='backbones.',
+ out_channels=[1536, 1536, 1536, 1536],
+ drop_path_rate = 0.0),
+ )
diff --git a/training/mono/configs/_base_/models/backbones/dino_vit_large_reg.py b/training/mono/configs/_base_/models/backbones/dino_vit_large_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..25e96747d459d42df299f8a6a1e14044a0e56164
--- /dev/null
+++ b/training/mono/configs/_base_/models/backbones/dino_vit_large_reg.py
@@ -0,0 +1,7 @@
+model = dict(
+ backbone=dict(
+ type='vit_large_reg',
+ prefix='backbones.',
+ out_channels=[1024, 1024, 1024, 1024],
+ drop_path_rate = 0.0),
+ )
diff --git a/training/mono/configs/_base_/models/backbones/dino_vit_small_reg.py b/training/mono/configs/_base_/models/backbones/dino_vit_small_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c8bd97dccb9cdee7517250f40e01bb3124144e6
--- /dev/null
+++ b/training/mono/configs/_base_/models/backbones/dino_vit_small_reg.py
@@ -0,0 +1,7 @@
+model = dict(
+ backbone=dict(
+ type='vit_small_reg',
+ prefix='backbones.',
+ out_channels=[384, 384, 384, 384],
+ drop_path_rate = 0.0),
+ )
diff --git a/training/mono/configs/_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py b/training/mono/configs/_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..73702d298c05979bcdf013e9c30ec56f4e36665b
--- /dev/null
+++ b/training/mono/configs/_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py
@@ -0,0 +1,19 @@
+# model settings
+_base_ = ['../backbones/dino_vit_giant2_reg.py']
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='RAFTDepthDPT',
+ in_channels=[1536, 1536, 1536, 1536],
+ use_cls_token=True,
+ feature_channels = [384, 768, 1536, 1536], # [2/7, 1/7, 1/14, 1/14]
+ decoder_channels = [192, 384, 768, 1536, 1536], # [4/7, 2/7, 1/7, 1/14, 1/14]
+ up_scale = 7,
+ hidden_channels=[192, 192, 192, 192], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
+ n_gru_layers=3,
+ n_downsample=2,
+ iters=3,
+ slow_fast_gru=True,
+ num_register_tokens=4,
+ prefix='decode_heads.'),
+)
diff --git a/training/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py b/training/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ab6dc090e9cdb840d84fab10587becb536dbb8
--- /dev/null
+++ b/training/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py
@@ -0,0 +1,19 @@
+# model settings
+_base_ = ['../backbones/dino_vit_large_reg.py']
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='RAFTDepthDPT',
+ in_channels=[1024, 1024, 1024, 1024],
+ use_cls_token=True,
+ feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14]
+ decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14]
+ up_scale = 7,
+ hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
+ n_gru_layers=3,
+ n_downsample=2,
+ iters=3,
+ slow_fast_gru=True,
+ num_register_tokens=4,
+ prefix='decode_heads.'),
+)
diff --git a/training/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py b/training/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..19466c191e9f2a83903e55ca4fc0827d9a11bcb9
--- /dev/null
+++ b/training/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py
@@ -0,0 +1,19 @@
+# model settings
+_base_ = ['../backbones/dino_vit_small_reg.py']
+model = dict(
+ type='DensePredModel',
+ decode_head=dict(
+ type='RAFTDepthDPT',
+ in_channels=[384, 384, 384, 384],
+ use_cls_token=True,
+ feature_channels = [96, 192, 384, 768], # [2/7, 1/7, 1/14, 1/14]
+ decoder_channels = [48, 96, 192, 384, 384], # [-, 1/4, 1/7, 1/14, 1/14]
+ up_scale = 7,
+ hidden_channels=[48, 48, 48, 48], # [x_4, x_8, x_16, x_32] [1/4, 1/7, 1/14, -]
+ n_gru_layers=3,
+ n_downsample=2,
+ iters=3,
+ slow_fast_gru=True,
+ num_register_tokens=4,
+ prefix='decode_heads.'),
+)
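+
+# A minimal commented sketch (assumption: the training code resolves the _base_
+# list with an mmcv-style Config loader) of how this file composes with its
+# backbone _base_:
+#   from mmcv.utils import Config
+#   cfg = Config.fromfile(
+#       'training/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py')
+#   cfg.model.backbone.type     # 'vit_small_reg', inherited from the _base_ backbone file
+#   cfg.model.decode_head.type  # 'RAFTDepthDPT', defined above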
diff --git a/training/mono/configs/_base_/schedules/schedule_1m.py b/training/mono/configs/_base_/schedules/schedule_1m.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b347f377bbe5751d8b24919d0e3eeb98b7d3900
--- /dev/null
+++ b/training/mono/configs/_base_/schedules/schedule_1m.py
@@ -0,0 +1,9 @@
+optimizer = dict(
+ type='SGD',
+ encoder=dict(lr=0.01, ),
+ decoder=dict(lr=0.01, ),
+)
+# learning policy
+lr_config = dict(policy='poly',) #dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
+
+
diff --git a/training/mono/configs/test_configs_vit/ddad.vit.dpt.raft.py b/training/mono/configs/test_configs_vit/ddad.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..8451071744d8f0cd0b1e7dcaaf4a7ce48f9157b0
--- /dev/null
+++ b/training/mono/configs/test_configs_vit/ddad.vit.dpt.raft.py
@@ -0,0 +1,94 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
+
+ '../_base_/datasets/ddad.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(DDAD='DDAD_dataset'),
+ ],
+]
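+# Added note: each entry maps a dataset name (e.g. DDAD) to the name of a dataset
+# config dict ('DDAD_dataset', defined below on top of the _base_ dataset file);
+# entries inside one inner list form a group of similar data merged together.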
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
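+# Added note: vit_size=(616, 1064) is an exact multiple of the DINOv2 patch size
+# 14 (616 = 44 * 14, 1064 = 76 * 14), as is the default resize_size in the test
+# pipeline below.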
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']
+DDAD_dataset=dict(
+ data = dict(
+ test=dict(
+ anno_path='DDAD/annotations/test_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(1216, 1952), #(544, 992), #
+ # resize_size=(560, 1008),
+ # resize_size=(840, 1512),
+ resize_size=(616,1064),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='ResizeKeepRatio',
+ # resize_size=(1120, 2016),
+ # ignore_label=-1,
+ # padding=[0,0,0],
+ # keep_gt=True),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 500,
+ ),
+ ))
+
+# DDAD_dataset=dict(
+# data = dict(
+# test=dict(
+# anno_path='DDAD/annotations/test_annotations.json',
+# pipeline=[dict(type='BGR2RGB'),
+# dict(type='KeepResizeCanoSize',
+# resize_size=(640, 1088), #(1216, 1952), #(512, 960), #
+# ignore_label=-1,
+# padding=[0, 0, 0]),
+# dict(type='ToTensor'),
+# dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+# ],
+# sample_ratio = 1.0,
+# sample_size = 80,
+# ),
+# ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit/diode.vit.dpt.raft.py b/training/mono/configs/test_configs_vit/diode.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..88e40fac91d4c36f87dbfe9394cfbdbfaea4dbc1
--- /dev/null
+++ b/training/mono/configs/test_configs_vit/diode.vit.dpt.raft.py
@@ -0,0 +1,66 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
+
+ '../_base_/datasets/diode.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(DIODE='DIODE_dataset'),
+ #dict(DIODE_indoor='DIODE_dataset')
+ #dict(DIODE_outdoor='DIODE_dataset')
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),# (0.3, 160),
+ # crop_size = (512, 960),
+ clip_depth_range=(0.1, 150),
+)
+
+
+
+# indoor (544, 928), outdoor: (768, 1088)
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'normal_median', 'normal_mean', 'normal_rmse', 'normal_a1', 'normal_a2', 'normal_a3', 'normal_a4', 'normal_a5']
+DIODE_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(544, 992), #(768, 1088), #(768, 1120), # (768, 1216), #(768, 1024), # (768, 1216), #(768, 1312), #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
diff --git a/training/mono/configs/test_configs_vit/eth3d.vit.dpt.raft.py b/training/mono/configs/test_configs_vit/eth3d.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..a65ee5d3c1320916f0200fe071cc6e586f128ae5
--- /dev/null
+++ b/training/mono/configs/test_configs_vit/eth3d.vit.dpt.raft.py
@@ -0,0 +1,70 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
+
+ '../_base_/datasets/eth3d.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(ETH3D='ETH3D_dataset'), #447.2w
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),# (0.3, 160),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+# indoor (544, 928), outdoor: (768, 1088)
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'normal_mean', 'normal_rmse', 'normal_a1']
+ETH3D_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(512, 512), #(768, 1088), #(768, 1120), # (768, 1216), #(768, 1024), # (768, 1216), #(768, 1312), # (512, 512)
+ resize_size=(616,1064),
+ # resize_size=(1120, 2016),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
diff --git a/training/mono/configs/test_configs_vit/ibims.vit.dpt.raft.py b/training/mono/configs/test_configs_vit/ibims.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..411ed1b5777d272816c7846564f23256e7dca222
--- /dev/null
+++ b/training/mono/configs/test_configs_vit/ibims.vit.dpt.raft.py
@@ -0,0 +1,71 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
+
+ '../_base_/datasets/ibims.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(IBIMS='IBIMS_dataset'), #447.2w
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),# (0.3, 160),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 10),
+ vit_size=(616,1064),
+)
+clip_depth = True
+
+# indoor (544, 928), outdoor: (768, 1088)
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'normal_mean', 'normal_rmse', 'normal_a3', 'normal_a4', 'normal_a5', 'normal_median']
+IBIMS_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(512, 512), #(768, 1088), #(768, 1120), # (768, 1216), #(768, 1024), # (768, 1216), #(768, 1312), # (512, 512)
+ resize_size=(616,1064),
+ # resize_size=(1120, 2016),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
diff --git a/training/mono/configs/test_configs_vit/kitti.vit.dpt.raft.py b/training/mono/configs/test_configs_vit/kitti.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..da061756ccdad39dd5a5748a21d94ba97bef8b66
--- /dev/null
+++ b/training/mono/configs/test_configs_vit/kitti.vit.dpt.raft.py
@@ -0,0 +1,82 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
+
+ '../_base_/datasets/kitti.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(KITTI='KITTI_dataset'),
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 80),
+ vit_size=(616,1064),
+)
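+# Added note: clip_depth_range caps ground-truth depth at 80 m, the standard
+# KITTI evaluation limit; the clip_depth flag below presumably enables this
+# clipping.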
+
+clip_depth = True
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'rmse_log',
+ 'log10']
+KITTI_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(416, 1248), #(480, 1216), #(512, 1088), #(512, 1312), #(480, 1248), # #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
+
+# DDAD_dataset=dict(
+# data = dict(
+# test=dict(
+# anno_path='DDAD/annotations/test_annotations.json',
+# pipeline=[dict(type='BGR2RGB'),
+# dict(type='KeepResizeCanoSize',
+# resize_size=(640, 1088), #(1216, 1952), #(512, 960), #
+# ignore_label=-1,
+# padding=[0, 0, 0]),
+# dict(type='ToTensor'),
+# dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+# ],
+# sample_ratio = 1.0,
+# sample_size = 80,
+# ),
+# ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit/nuscenes.vit.dpt.raft.py b/training/mono/configs/test_configs_vit/nuscenes.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..c81aebc1da766c67db1fc3cda9421a3fe4f6ade3
--- /dev/null
+++ b/training/mono/configs/test_configs_vit/nuscenes.vit.dpt.raft.py
@@ -0,0 +1,93 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
+
+ '../_base_/datasets/nuscenes.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(NuScenes='NuScenes_dataset'),
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']
+NuScenes_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(1216, 1952), #(544, 992), #
+ # resize_size=(560, 1008),
+ # resize_size=(840, 1512),
+ resize_size=(616,1064),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='ResizeKeepRatio',
+ # resize_size=(1120, 2016),
+ # ignore_label=-1,
+ # padding=[0,0,0],
+ # keep_gt=True),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 500,
+ ),
+ ))
+
+# DDAD_dataset=dict(
+# data = dict(
+# test=dict(
+# anno_path='DDAD/annotations/test_annotations.json',
+# pipeline=[dict(type='BGR2RGB'),
+# dict(type='KeepResizeCanoSize',
+# resize_size=(640, 1088), #(1216, 1952), #(512, 960), #
+# ignore_label=-1,
+# padding=[0, 0, 0]),
+# dict(type='ToTensor'),
+# dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+# ],
+# sample_ratio = 1.0,
+# sample_size = 80,
+# ),
+# ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit/nyu.vit.dpt.raft.py b/training/mono/configs/test_configs_vit/nyu.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ea74f5c4515c3db46fcba51c645aa4f847c7bcd
--- /dev/null
+++ b/training/mono/configs/test_configs_vit/nyu.vit.dpt.raft.py
@@ -0,0 +1,64 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
+
+ '../_base_/datasets/nyu.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(NYU='NYU_dataset'),
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 10),
+ vit_size=(616,1064),
+)
+clip_depth = True
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'rmse_log', 'log10', 'normal_mean', 'normal_rmse', 'normal_median', 'normal_a3', 'normal_a4', 'normal_a5']
+NYU_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(544, 992), #(480, 1216), #(480, 640), #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit/replica.vit.dpt.raft.py b/training/mono/configs/test_configs_vit/replica.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..6843c92ed9877e5e24b49b575f0780b81f1583b7
--- /dev/null
+++ b/training/mono/configs/test_configs_vit/replica.vit.dpt.raft.py
@@ -0,0 +1,64 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
+
+ '../_base_/datasets/replica.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(Replica='Replica_dataset'), # 5.6w
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),# (0.3, 160),
+ # crop_size = (512, 960),
+ clip_depth_range=(0.1, 200),
+)
+
+
+
+# indoor (544, 928), outdoor: (768, 1088)
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'normal_median', 'normal_mean', 'normal_rmse', 'normal_a1', 'normal_a2', 'normal_a3', 'normal_a4', 'normal_a5']
+Replica_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(544, 992), #(768, 1088), #(768, 1120), # (768, 1216), #(768, 1024), # (768, 1216), #(768, 1312), #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
diff --git a/training/mono/configs/test_configs_vit/scannet.vit.dpt.raft.py b/training/mono/configs/test_configs_vit/scannet.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..6524815ef16402869c57a9df1423a8b442c7fb25
--- /dev/null
+++ b/training/mono/configs/test_configs_vit/scannet.vit.dpt.raft.py
@@ -0,0 +1,66 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py',
+
+ '../_base_/datasets/scannet.py',
+ '../_base_/datasets/scannet_all.py',
+ #'../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ #dict(ScanNet='ScanNet_dataset'),
+ dict(ScanNetAll='ScanNetAll_dataset')
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'rmse_log', 'log10', 'normal_mean', 'normal_rmse', 'normal_median', 'normal_a3', 'normal_a4', 'normal_a5']
+ScanNetAll_dataset=dict(
+#ScanNet_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(544, 992), #(480, 1216), #(480, 640), #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 500,
+ ),
+ ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_giant2/ddad.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_giant2/ddad.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..e10f34d62e9c26180cac7ecdc681f8f961a3a162
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_giant2/ddad.vit.dpt.raft.py
@@ -0,0 +1,94 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/ddad.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(DDAD='DDAD_dataset'),
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']
+DDAD_dataset=dict(
+ data = dict(
+ test=dict(
+ anno_path='DDAD/annotations/test_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(1216, 1952), #(544, 992), #
+ # resize_size=(560, 1008),
+ # resize_size=(840, 1512),
+ resize_size=(616,1064),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='ResizeKeepRatio',
+ # resize_size=(1120, 2016),
+ # ignore_label=-1,
+ # padding=[0,0,0],
+ # keep_gt=True),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 500,
+ ),
+ ))
+
+# DDAD_dataset=dict(
+# data = dict(
+# test=dict(
+# anno_path='DDAD/annotations/test_annotations.json',
+# pipeline=[dict(type='BGR2RGB'),
+# dict(type='KeepResizeCanoSize',
+# resize_size=(640, 1088), #(1216, 1952), #(512, 960), #
+# ignore_label=-1,
+# padding=[0, 0, 0]),
+# dict(type='ToTensor'),
+# dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+# ],
+# sample_ratio = 1.0,
+# sample_size = 80,
+# ),
+# ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_giant2/diode.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_giant2/diode.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf203976e9ac02fa32bd501e61908c876ec74b7c
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_giant2/diode.vit.dpt.raft.py
@@ -0,0 +1,66 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/diode.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ #dict(DIODE='DIODE_dataset'),
+ #dict(DIODE_indoor='DIODE_dataset')
+ dict(DIODE_outdoor='DIODE_dataset')
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),# (0.3, 160),
+ # crop_size = (512, 960),
+ clip_depth_range=(0.1, 150),
+)
+
+
+
+# indoor (544, 928), outdoor: (768, 1088)
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'normal_mean', 'normal_rmse', 'normal_a1']
+DIODE_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(544, 992), #(768, 1088), #(768, 1120), # (768, 1216), #(768, 1024), # (768, 1216), #(768, 1312), #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
diff --git a/training/mono/configs/test_configs_vit_giant2/dsec.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_giant2/dsec.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..a12a59c3aea652bd85ae036c1991355c92bff757
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_giant2/dsec.vit.dpt.raft.py
@@ -0,0 +1,95 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/dsec.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(DSEC='DSEC_dataset'),
+ ],
+]
+
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']
+DSEC_dataset=dict(
+ data = dict(
+ test=dict(
+ anno_path='DSEC/annotations/test_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(1216, 1952), #(544, 992), #
+ # resize_size=(560, 1008),
+ # resize_size=(840, 1512),
+ resize_size=(616,1064),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='ResizeKeepRatio',
+ # resize_size=(1120, 2016),
+ # ignore_label=-1,
+ # padding=[0,0,0],
+ # keep_gt=True),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 500,
+ ),
+ ))
+
+# DDAD_dataset=dict(
+# data = dict(
+# test=dict(
+# anno_path='DDAD/annotations/test_annotations.json',
+# pipeline=[dict(type='BGR2RGB'),
+# dict(type='KeepResizeCanoSize',
+# resize_size=(640, 1088), #(1216, 1952), #(512, 960), #
+# ignore_label=-1,
+# padding=[0, 0, 0]),
+# dict(type='ToTensor'),
+# dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+# ],
+# sample_ratio = 1.0,
+# sample_size = 80,
+# ),
+# ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_giant2/eth3d.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_giant2/eth3d.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fb27193e8e6a608a7a187866455150824b4fbf8
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_giant2/eth3d.vit.dpt.raft.py
@@ -0,0 +1,70 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/eth3d.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(ETH3D='ETH3D_dataset'), #447.2w
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),# (0.3, 160),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+# indoor (544, 928), outdoor: (768, 1088)
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'normal_mean', 'normal_rmse', 'normal_a1']
+ETH3D_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(512, 512), #(768, 1088), #(768, 1120), # (768, 1216), #(768, 1024), # (768, 1216), #(768, 1312), # (512, 512)
+ resize_size=(616,1064),
+ # resize_size=(1120, 2016),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
diff --git a/training/mono/configs/test_configs_vit_giant2/ibims.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_giant2/ibims.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..4523fb35a715bfb7f4c63ca93e3ea4e934eb604c
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_giant2/ibims.vit.dpt.raft.py
@@ -0,0 +1,71 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/ibims.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(IBIMS='IBIMS_dataset'), #447.2w
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),# (0.3, 160),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 10),
+ vit_size=(616,1064),
+)
+clip_depth = True
+
+# indoor (544, 928), outdoor: (768, 1088)
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'normal_mean', 'normal_rmse', 'normal_a3', 'normal_a4', 'normal_a5', 'normal_median']
+IBIMS_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(512, 512), #(768, 1088), #(768, 1120), # (768, 1216), #(768, 1024), # (768, 1216), #(768, 1312), # (512, 512)
+ resize_size=(616,1064),
+ # resize_size=(1120, 2016),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
diff --git a/training/mono/configs/test_configs_vit_giant2/kitti.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_giant2/kitti.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..4807c46ff1478c956991222a7389742b50f0560f
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_giant2/kitti.vit.dpt.raft.py
@@ -0,0 +1,82 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/kitti.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(KITTI='KITTI_dataset'),
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 80),
+ vit_size=(616,1064),
+)
+
+clip_depth = False
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'rmse_log',
+ 'log10']
+KITTI_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(416, 1248), #(480, 1216), #(512, 1088), #(512, 1312), #(480, 1248), # #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
+
+# DDAD_dataset=dict(
+# data = dict(
+# test=dict(
+# anno_path='DDAD/annotations/test_annotations.json',
+# pipeline=[dict(type='BGR2RGB'),
+# dict(type='KeepResizeCanoSize',
+# resize_size=(640, 1088), #(1216, 1952), #(512, 960), #
+# ignore_label=-1,
+# padding=[0, 0, 0]),
+# dict(type='ToTensor'),
+# dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+# ],
+# sample_ratio = 1.0,
+# sample_size = 80,
+# ),
+# ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_giant2/nuscenes.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_giant2/nuscenes.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..d783a19447b03af1a62c92b0898d182c25fb641e
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_giant2/nuscenes.vit.dpt.raft.py
@@ -0,0 +1,93 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/nuscenes.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(NuScenes='NuScenes_dataset'),
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']
+NuScenes_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(1216, 1952), #(544, 992), #
+ # resize_size=(560, 1008),
+ # resize_size=(840, 1512),
+ resize_size=(616,1064),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='ResizeKeepRatio',
+ # resize_size=(1120, 2016),
+ # ignore_label=-1,
+ # padding=[0,0,0],
+ # keep_gt=True),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 500,
+ ),
+ ))
+
+# DDAD_dataset=dict(
+# data = dict(
+# test=dict(
+# anno_path='DDAD/annotations/test_annotations.json',
+# pipeline=[dict(type='BGR2RGB'),
+# dict(type='KeepResizeCanoSize',
+# resize_size=(640, 1088), #(1216, 1952), #(512, 960), #
+# ignore_label=-1,
+# padding=[0, 0, 0]),
+# dict(type='ToTensor'),
+# dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+# ],
+# sample_ratio = 1.0,
+# sample_size = 80,
+# ),
+# ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_giant2/nyu.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_giant2/nyu.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..80f75f8a6c6a009294e8818f9d8d780e54f1f277
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_giant2/nyu.vit.dpt.raft.py
@@ -0,0 +1,64 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/nyu.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(NYU='NYU_dataset'),
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 10),
+ vit_size=(616,1064),
+)
+clip_depth = True
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'rmse_log', 'log10', 'normal_mean', 'normal_rmse', 'normal_median', 'normal_a3', 'normal_a4', 'normal_a5']
+NYU_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(544, 992), #(480, 1216), #(480, 640), #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_giant2/scannet.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_giant2/scannet.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c556d92cc21cb877251d378e66f1cc0475f0430
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_giant2/scannet.vit.dpt.raft.py
@@ -0,0 +1,66 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/scannet.py',
+ '../_base_/datasets/scannet_all.py',
+ #'../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ #dict(ScanNet='ScanNet_dataset'),
+ dict(ScanNetAll='ScanNetAll_dataset')
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'rmse_log', 'log10', 'normal_mean', 'normal_rmse', 'normal_median', 'normal_a3', 'normal_a4', 'normal_a5']
+ScanNetAll_dataset=dict(
+#ScanNet_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(544, 992), #(480, 1216), #(480, 640), #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 500,
+ ),
+ ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_giant2/waymo.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_giant2/waymo.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0a425d3f89f6215d51528a783e6a2b47f22480c
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_giant2/waymo.vit.dpt.raft.py
@@ -0,0 +1,95 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py',
+
+ '../_base_/datasets/waymo.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=8,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(Waymo='Waymo_dataset'),
+ ],
+]
+
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']
+Waymo_dataset=dict(
+ data = dict(
+ test=dict(
+ anno_path='Waymo/annotations/test_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(1216, 1952), #(544, 992), #
+ # resize_size=(560, 1008),
+ # resize_size=(840, 1512),
+ resize_size=(616,1064),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='ResizeKeepRatio',
+ # resize_size=(1120, 2016),
+ # ignore_label=-1,
+ # padding=[0,0,0],
+ # keep_gt=True),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 500,
+ ),
+ ))
+
+# DDAD_dataset=dict(
+# data = dict(
+# test=dict(
+# anno_path='DDAD/annotations/test_annotations.json',
+# pipeline=[dict(type='BGR2RGB'),
+# dict(type='KeepResizeCanoSize',
+# resize_size=(640, 1088), #(1216, 1952), #(512, 960), #
+# ignore_label=-1,
+# padding=[0, 0, 0]),
+# dict(type='ToTensor'),
+# dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+# ],
+# sample_ratio = 1.0,
+# sample_size = 80,
+# ),
+# ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_small/ddad.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_small/ddad.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebf6bb7cc90136cfe0485d8c6171816f12d98e40
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_small/ddad.vit.dpt.raft.py
@@ -0,0 +1,94 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/ddad.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(DDAD='DDAD_dataset'),
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']
+DDAD_dataset=dict(
+ data = dict(
+ test=dict(
+ anno_path='DDAD/annotations/test_annotations.json',
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(1216, 1952), #(544, 992), #
+ # resize_size=(560, 1008),
+ # resize_size=(840, 1512),
+ resize_size=(616,1064),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='ResizeKeepRatio',
+ # resize_size=(1120, 2016),
+ # ignore_label=-1,
+ # padding=[0,0,0],
+ # keep_gt=True),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 500,
+ ),
+ ))
+
+# DDAD_dataset=dict(
+# data = dict(
+# test=dict(
+# anno_path='DDAD/annotations/test_annotations.json',
+# pipeline=[dict(type='BGR2RGB'),
+# dict(type='KeepResizeCanoSize',
+# resize_size=(640, 1088), #(1216, 1952), #(512, 960), #
+# ignore_label=-1,
+# padding=[0, 0, 0]),
+# dict(type='ToTensor'),
+# dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+# ],
+# sample_ratio = 1.0,
+# sample_size = 80,
+# ),
+# ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_small/diode.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_small/diode.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..545911616d74712e121196d1893c383a3ec233da
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_small/diode.vit.dpt.raft.py
@@ -0,0 +1,66 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/diode.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model=dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ #dict(DIODE='DIODE_dataset'),
+ #dict(DIODE_indoor='DIODE_dataset')
+ dict(DIODE_outdoor='DIODE_dataset')
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),# (0.3, 160),
+ # crop_size = (512, 960),
+ clip_depth_range=(0.1, 150),
+)
+
+
+
+# indoor (544, 928), outdoor: (768, 1088)
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'normal_median', 'normal_mean', 'normal_rmse', 'normal_a1', 'normal_a2', 'normal_a3', 'normal_a4', 'normal_a5']
+DIODE_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(544, 992), #(768, 1088), #(768, 1120), # (768, 1216), #(768, 1024), # (768, 1216), #(768, 1312), #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
diff --git a/training/mono/configs/test_configs_vit_small/eth3d.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_small/eth3d.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a9c035bc3fcdfb64657a2ef459d193f2c8c530c
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_small/eth3d.vit.dpt.raft.py
@@ -0,0 +1,70 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/eth3d.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(ETH3D='ETH3D_dataset'), #447.2w
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),# (0.3, 160),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+# indoor (544, 928), outdoor: (768, 1088)
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'normal_mean', 'normal_rmse', 'normal_a1']
+ETH3D_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(512, 512), #(768, 1088), #(768, 1120), # (768, 1216), #(768, 1024), # (768, 1216), #(768, 1312), # (512, 512)
+ resize_size=(616,1064),
+ # resize_size=(1120, 2016),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
diff --git a/training/mono/configs/test_configs_vit_small/ibims.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_small/ibims.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4732570df5f65bfed63f0459a50719d44efff77
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_small/ibims.vit.dpt.raft.py
@@ -0,0 +1,70 @@
+_base_=['../_base_/losses/all_losses.py',
+       '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/ibims.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(IBIMS='IBIMS_dataset'), #447.2w
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),# (0.3, 160),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+# indoor (544, 928), outdoor: (768, 1088)
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'normal_mean', 'normal_rmse', 'normal_a3', 'normal_a4', 'normal_a5', 'normal_median']
+IBIMS_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(512, 512), #(768, 1088), #(768, 1120), # (768, 1216), #(768, 1024), # (768, 1216), #(768, 1312), # (512, 512)
+ resize_size=(616,1064),
+ # resize_size=(1120, 2016),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
diff --git a/training/mono/configs/test_configs_vit_small/kitti.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_small/kitti.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..8966f5c7dcfcc791bbc192231337b8e36f509eb2
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_small/kitti.vit.dpt.raft.py
@@ -0,0 +1,81 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/kitti.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(KITTI='KITTI_dataset'),
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'rmse_log',
+ 'log10']
+KITTI_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(416, 1248), #(480, 1216), #(512, 1088), #(512, 1312), #(480, 1248), # #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
+
+# DDAD_dataset=dict(
+# data = dict(
+# test=dict(
+# anno_path='DDAD/annotations/test_annotations.json',
+# pipeline=[dict(type='BGR2RGB'),
+# dict(type='KeepResizeCanoSize',
+# resize_size=(640, 1088), #(1216, 1952), #(512, 960), #
+# ignore_label=-1,
+# padding=[0, 0, 0]),
+# dict(type='ToTensor'),
+# dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+# ],
+# sample_ratio = 1.0,
+# sample_size = 80,
+# ),
+# ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_small/nuscenes.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_small/nuscenes.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..25f9e065b05930e6512e373b7068e1bbf9ae9d8a
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_small/nuscenes.vit.dpt.raft.py
@@ -0,0 +1,93 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/nuscenes.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(NuScenes='NuScenes_dataset'),
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']
+NuScenes_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(1216, 1952), #(544, 992), #
+ # resize_size=(560, 1008),
+ # resize_size=(840, 1512),
+ resize_size=(616,1064),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='ResizeKeepRatio',
+ # resize_size=(1120, 2016),
+ # ignore_label=-1,
+ # padding=[0,0,0],
+ # keep_gt=True),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
+
+# DDAD_dataset=dict(
+# data = dict(
+# test=dict(
+# anno_path='DDAD/annotations/test_annotations.json',
+# pipeline=[dict(type='BGR2RGB'),
+# dict(type='KeepResizeCanoSize',
+# resize_size=(640, 1088), #(1216, 1952), #(512, 960), #
+# ignore_label=-1,
+# padding=[0, 0, 0]),
+# dict(type='ToTensor'),
+# dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+# ],
+# sample_ratio = 1.0,
+# sample_size = 80,
+# ),
+# ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_small/nyu.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_small/nyu.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3f23c53a158d103bb479967c7981e81f8c9fd49
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_small/nyu.vit.dpt.raft.py
@@ -0,0 +1,63 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/nyu.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(NYU='NYU_dataset'),
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'rmse_log', 'log10', 'normal_mean', 'normal_rmse', 'normal_median', 'normal_a3', 'normal_a4', 'normal_a5']
+NYU_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(544, 992), #(480, 1216), #(480, 640), #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = -1,
+ ),
+ ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_small/scannet.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_small/scannet.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc5308680a13074702799c67a06160f1c007dca4
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_small/scannet.vit.dpt.raft.py
@@ -0,0 +1,66 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/scannet.py',
+ '../_base_/datasets/scannet_all.py',
+ #'../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ #dict(ScanNet='ScanNet_dataset'),
+ dict(ScanNetAll='ScanNetAll_dataset')
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0,1),
+ depth_normalize=(0.1, 200),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'rmse_log', 'log10', 'normal_mean', 'normal_rmse', 'normal_median', 'normal_a3', 'normal_a4', 'normal_a5']
+ScanNetAll_dataset=dict(
+#ScanNet_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ resize_size=(616, 1064), #(544, 992), #(480, 1216), #(480, 640), #
+ ignore_label=-1,
+ padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 500,
+ ),
+ ))
\ No newline at end of file
diff --git a/training/mono/configs/test_configs_vit_small/taskonomy.vit.dpt.raft.py b/training/mono/configs/test_configs_vit_small/taskonomy.vit.dpt.raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..638c945b32bc013f7d13cdf636587fe2643ece39
--- /dev/null
+++ b/training/mono/configs/test_configs_vit_small/taskonomy.vit.dpt.raft.py
@@ -0,0 +1,70 @@
+_base_=['../_base_/losses/all_losses.py',
+ '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py',
+
+ '../_base_/datasets/taskonomy.py',
+ '../_base_/datasets/_data_base_.py',
+
+ '../_base_/default_runtime.py',
+ '../_base_/schedules/schedule_1m.py'
+ ]
+
+import numpy as np
+
+model = dict(
+ decode_head=dict(
+ type='RAFTDepthNormalDPT5',
+ iters=4,
+ n_downsample=2,
+ detach=False,
+ )
+)
+
+# model settings
+find_unused_parameters = True
+
+
+
+# data configs, some similar data are merged together
+data_array = [
+ # group 1
+ [
+ dict(Taskonomy='Taskonomy_dataset'), #447.2w
+ ],
+]
+data_basic=dict(
+ canonical_space = dict(
+ # img_size=(540, 960),
+ focal_length=1000.0,
+ ),
+ depth_range=(0, 1),
+ depth_normalize=(0.1, 200),# (0.3, 160),
+ crop_size = (1120, 2016),
+ clip_depth_range=(0.1, 200),
+ vit_size=(616,1064),
+)
+
+# indoor (544, 928), outdoor: (768, 1088)
+test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3', 'normal_mean', 'normal_rmse', 'normal_median', 'normal_a3', 'normal_a4', 'normal_a5']
+Taskonomy_dataset=dict(
+ data = dict(
+ test=dict(
+ pipeline=[dict(type='BGR2RGB'),
+ dict(type='LabelScaleCononical'),
+ dict(type='ResizeKeepRatio',
+ # resize_size=(512, 512), #(768, 1088), #(768, 1120), # (768, 1216), #(768, 1024), # (768, 1216), #(768, 1312), # (512, 512)
+ resize_size=(616,1064),
+ # resize_size=(1120, 2016),
+ ignore_label=-1,
+ padding=[0,0,0]),
+ # dict(type='RandomCrop',
+ # crop_size=(0,0),
+ # crop_type='center',
+ # ignore_label=-1,
+ # padding=[0,0,0]),
+ dict(type='ToTensor'),
+ dict(type='Normalize', mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
+ ],
+ sample_ratio = 1.0,
+ sample_size = 500,
+ ),
+ ))
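
All of these test configs share the same data_basic block: a canonical camera with focal_length=1000.0, depth clipping/normalization over (0.1, 200), and vit_size=(616, 1064), with the LabelScaleCononical pipeline step mapping each dataset's labels into that canonical camera. A minimal sketch of the underlying idea, assuming the usual convention that depth labels are scaled by the ratio of canonical to actual focal length; the real transform lives in mono/utils/transform.py, which is outside this excerpt:

    # Hedged sketch of the canonical-space rescaling behind LabelScaleCononical.
    # Assumption: depth is scaled by f_canonical / f_image so every dataset looks
    # as if it were captured with a 1000-pixel focal length.
    def to_canonical_depth(depth_m, fx, fy, f_canonical=1000.0):
        f_image = (fx + fy) / 2.0        # focal length averaged, as in the camera-model encoding below
        scale = f_canonical / f_image    # corresponds to the label_scale_factor in transform_paras
        return depth_m * scale

    # A KITTI-like camera (fx = fy = 721.5) maps a 10 m point to roughly 13.86 m in canonical space.
    print(round(to_canonical_depth(10.0, 721.5, 721.5), 2))
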
diff --git a/training/mono/datasets/__base_dataset__.py b/training/mono/datasets/__base_dataset__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a138759c4a022fe403a4b15fc80e436a71ed49b1
--- /dev/null
+++ b/training/mono/datasets/__base_dataset__.py
@@ -0,0 +1,586 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+import mono.utils.transform as img_transform
+import copy
+from mono.utils.comm import get_func
+import pickle
+import logging
+import multiprocessing as mp
+import ctypes
+"""
+Dataset annotations are saved in a JSON file. All data captured within the same frame, including rgb, depth, pose, and so on, are saved in the same dict.
+All frames are organized in a list. Each frame may contain some or all of the following fields.
+
+# Annotations for the current central RGB/depth cameras.
+
+'rgb': rgb image in the current frame.
+'depth': depth map in the current frame.
+'sem': semantic mask in the current frame.
+'cam_in': camera intrinsic parameters of the current rgb camera.
+'cam_ex': camera extrinsic parameters of the current rgb camera.
+'cam_ex_path': path to the extrinsic parameters.
+'pose': pose in current frame.
+'timestamp_rgb': time stamp of current rgb image.
+
+# Annotations for the left hand RGB/depth cameras.
+
+'rgb_l': rgb image of the left hand camera in the current frame.
+'depth_l': depth map of the left hand camera in the current frame.
+'sem_l': semantic mask of the left hand camera in the current frame.
+'cam_in_l': camera intrinsic parameters of the left hand rgb camera in the current frame.
+'cam_ex_l': camera extrinsic parameters of the left hand rgb camera in the current frame.
+'cam_ex_path': path to the extrinsic parameters.
+'pose_l': pose of the left hand camera in the current frame.
+'timestamp_rgb_l': time stamp of the rgb img captured by the left hand camera.
+
+# Annotations for the right hand RGB/depth cameras, which are on the right hand side of the current central cameras.
+
+'rgb_r': rgb image of the right hand camera in the current frame.
+'depth_r': depth map of the right hand camera in the current frame.
+'sem_r': semantic mask of the right hand camera in the current frame.
+'cam_in_r': camera intrinsic parameters of the right hand rgb camera in the current frame.
+'cam_ex_r': camera extrinsic parameters of the right hand rgb camera in the current frame.
+'cam_ex_path_r': path to the extrinsic parameters.
+'pose_r': pose of the right hand camera in the current frame.
+'timestamp_rgb_r': time stamp of the rgb img captured by the right hand camera.
+
+# Annotations for the central RGB/depth cameras in the last frame.
+
+'rgb_pre': rgb image of the central camera in the last frame.
+'depth_pre': depth map of the central camera in the last frame.
+'sem_pre': semantic mask of the central camera in the last frame.
+'cam_in_pre': camera intrinsic parameters of the central rgb camera in the last frame.
+'cam_ex_pre': camera extrinsic parameters of the central rgb camera in the last frame.
+'cam_ex_path_pre': path to the extrinsic parameters.
+'pose_pre': pose of the central camera in the last frame.
+'timestamp_rgb_pre': time stamp of the rgb img captured by the central camera.
+
+# Annotations for the central RGB/depth cameras in the next frame.
+
+'rgb_next': rgb image of the central camera in the next frame.
+'depth_next': depth map of the central camera in the next frame.
+'sem_next': semantic mask of the central camera in the next frame.
+'cam_in_next': camera intrinsic parameters of the central rgb camera in the next frame.
+'cam_ex_next': camera extrinsic parameters of the central rgb camera in the next frame.
+'cam_ex_path_next': path to the extrinsic parameters.
+'pose_next': pose of the central camera in the next frame.
+'timestamp_rgb_next': time stamp of the rgb img captured by the central camera.
+"""
+
+class BaseDataset(Dataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(BaseDataset, self).__init__()
+ self.cfg = cfg
+ self.phase = phase
+ self.db_info = kwargs['db_info']
+
+ # root dir for data
+ self.data_root = os.path.join(self.db_info['db_root'], self.db_info['data_root'])
+ # depth/disp data root
+ disp_root = self.db_info['disp_root'] if 'disp_root' in self.db_info else None
+ self.disp_root = os.path.join(self.db_info['db_root'], disp_root) if disp_root is not None else None
+ depth_root = self.db_info['depth_root'] if 'depth_root' in self.db_info else None
+ self.depth_root = os.path.join(self.db_info['db_root'], depth_root) if depth_root is not None \
+ else self.data_root
+ # meta data root
+ meta_data_root = self.db_info['meta_data_root'] if 'meta_data_root' in self.db_info else None
+ self.meta_data_root = os.path.join(self.db_info['db_root'], meta_data_root) if meta_data_root is not None \
+ else None
+ # semantic segmentation labels root
+ sem_root = self.db_info['semantic_root'] if 'semantic_root' in self.db_info else None
+ self.sem_root = os.path.join(self.db_info['db_root'], sem_root) if sem_root is not None \
+ else None
+ # depth valid mask labels root
+ depth_mask_root = self.db_info['depth_mask_root'] if 'depth_mask_root' in self.db_info else None
+ self.depth_mask_root = os.path.join(self.db_info['db_root'], depth_mask_root) if depth_mask_root is not None \
+ else None
+ # surface normal labels root
+ norm_root = self.db_info['normal_root'] if 'normal_root' in self.db_info else None
+ self.norm_root = os.path.join(self.db_info['db_root'], norm_root) if norm_root is not None \
+ else None
+ # data annotations path
+ self.data_annos_path = os.path.join(self.db_info['db_root'], self.db_info['%s_annotations_path' % phase])
+
+ # load annotations
+ self.data_info = self.load_annotations()
+ whole_data_size = len(self.data_info['files'])
+
+ # sample a subset for training/validation/testing
+ # note: this sampling approach is deprecated, since each training run may draw a different sample list
+
+ cfg_sample_ratio = cfg.data[phase].sample_ratio
+ cfg_sample_size = int(cfg.data[phase].sample_size)
+ self.sample_size = int(whole_data_size * cfg_sample_ratio) if cfg_sample_size == -1 \
+ else (cfg_sample_size if cfg_sample_size < whole_data_size else whole_data_size)
+ random.seed(100) # set the random seed
+ sample_list_of_whole_data = random.sample(list(range(whole_data_size)), self.sample_size)
+
+ self.data_size = self.sample_size
+ self.annotations = {'files': [self.data_info['files'][i] for i in sample_list_of_whole_data]}
+ self.sample_list = list(range(self.data_size))
+
+ # config transforms for the input and label
+ self.transforms_cfg = cfg.data[phase]['pipeline']
+ self.transforms_lib = 'mono.utils.transform.'
+
+ self.img_file_type = ['.png', '.jpg', '.jpeg', '.bmp', '.tif']
+ self.np_file_type = ['.npz', '.npy']
+
+ # update canonical space information
+ self.data_basic = copy.deepcopy(kwargs)
+ canonical = self.data_basic.pop('canonical_space')
+ self.data_basic.update(canonical)
+ self.disp_scale = 10.0
+ self.depth_range = kwargs['depth_range'] # predefined depth range for the network
+ self.clip_depth_range = kwargs['clip_depth_range'] # predefined depth range for data processing
+ self.depth_normalize = kwargs['depth_normalize']
+
+ self.img_transforms = img_transform.Compose(self.build_data_transforms())
+ self.EPS = 1e-6
+
+ # self.tmpl_info = ['rgb_sr', 'rgb_pre', 'rgb_next']
+ # self.tgt2ref_pose_lookup = {'rgb_sr': 'cam_ex', 'rgb_pre': 'pose_pre', 'rgb_next': 'pose_next'}
+
+ # dataset info
+ self.data_name = cfg.data_name
+ self.data_type = cfg.data_type # there are mainly four types, i.e. ['rel', 'sfm', 'stereo', 'lidar']
+ self.logger = logging.getLogger()
+ self.logger.info(f'{self.data_name} in {self.phase} whole data size: {whole_data_size}')
+
+ # random crop size for training
+ crop_size = kwargs['crop_size']
+ shared_array_base = mp.Array(ctypes.c_int32, 2)
+ shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
+ shared_array[0] = crop_size[0]
+ shared_array[1] = crop_size[1]
+ # self.random_crop_size = torch.from_numpy(np.array([0,0])) #torch.from_numpy(shared_array)
+ self.random_crop_size = torch.from_numpy(shared_array)
+
+
+ def __name__(self):
+ return self.data_name
+
+ def __len__(self):
+ return self.data_size
+
+ def load_annotations(self):
+ if not os.path.exists(self.data_annos_path):
+ raise RuntimeError(f'Cannot find {self.data_annos_path} annotations.')
+
+ with open(self.data_annos_path, 'r') as f:
+ annos = json.load(f)
+ return annos
+
+ def build_data_transforms(self):
+ transforms_list = []
+ for transform in self.transforms_cfg:
+ args = copy.deepcopy(transform)
+ # insert the canonical space configs
+ args.update(self.data_basic)
+
+ obj_name = args.pop('type')
+ obj_path = self.transforms_lib + obj_name
+ obj_cls = get_func(obj_path)
+
+ obj = obj_cls(**args)
+ transforms_list.append(obj)
+ return transforms_list
+
+
+ def load_data(self, path: str, is_rgb_img: bool=False):
+ if not os.path.exists(path):
+ self.logger.info(f'>>>>{path} does not exist.')
+ # raise RuntimeError(f'{path} does not exist.')
+
+ data_type = os.path.splitext(path)[-1]
+ if data_type in self.img_file_type:
+ if is_rgb_img:
+ data = cv2.imread(path)
+ else:
+ data = cv2.imread(path, -1)
+ elif data_type in self.np_file_type:
+ data = np.load(path)
+ else:
+ raise RuntimeError(f'{data_type} is not supported in current version.')
+
+ try:
+ return data.squeeze()
+ except:
+ temp = 1
+ raise RuntimeError(f'{path} is not successfully loaded.')
+
+ def __getitem__(self, idx: int) -> dict:
+ if self.phase == 'test':
+ return self.get_data_for_test(idx)
+ else:
+ return self.get_data_for_trainval(idx)
+
+ def get_data_for_trainval(self, idx: int):
+ anno = self.annotations['files'][idx]
+ meta_data = self.load_meta_data(anno)
+
+ data_path = self.load_data_path(meta_data)
+ data_batch = self.load_batch(meta_data, data_path)
+ # if data_path['sem_path'] is not None:
+ # print(self.data_name)
+
+ curr_rgb, curr_depth, curr_normal, curr_sem, curr_cam_model = data_batch['curr_rgb'], data_batch['curr_depth'], data_batch['curr_normal'], data_batch['curr_sem'], data_batch['curr_cam_model']
+ #curr_stereo_depth = data_batch['curr_stereo_depth']
+
+ # A patch for stereo depth dataloader (no need to modify specific datasets)
+ if 'curr_stereo_depth' in data_batch.keys():
+ curr_stereo_depth = data_batch['curr_stereo_depth']
+ else:
+ curr_stereo_depth = self.load_stereo_depth_label(None, H=curr_rgb.shape[0], W=curr_rgb.shape[1])
+
+ curr_intrinsic = meta_data['cam_in']
+ # data augmentation
+ transform_paras = dict(random_crop_size = self.random_crop_size) # dict()
+ assert curr_rgb.shape[:2] == curr_depth.shape == curr_normal.shape[:2] == curr_sem.shape
+ rgbs, depths, intrinsics, cam_models, normals, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb, ],
+ labels=[curr_depth, ],
+ intrinsics=[curr_intrinsic,],
+ cam_models=[curr_cam_model, ],
+ normals = [curr_normal, ],
+ other_labels=[curr_sem, curr_stereo_depth],
+ transform_paras=transform_paras)
+ # process sky masks
+ sem_mask = other_labels[0].int()
+ # clip depth map
+ depth_out = self.normalize_depth(depths[0])
+ # set the depth of sky region to the invalid
+ depth_out[sem_mask==142] = -1 # self.depth_normalize[1] - 1e-6
+ # get inverse depth
+ inv_depth = self.depth2invdepth(depth_out, sem_mask==142)
+ filename = os.path.basename(meta_data['rgb'])[:-4] + '.jpg'
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+
+ # stereo_depth
+ if 'label_scale_factor' not in transform_paras.keys():
+ transform_paras['label_scale_factor'] = 1
+ stereo_depth_pre_trans = other_labels[1] * (other_labels[1] > 0.3) * (other_labels[1] < 200)
+ stereo_depth = stereo_depth_pre_trans * transform_paras['label_scale_factor']
+ stereo_depth = self.normalize_depth(stereo_depth)
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ pad=torch.tensor(pad),
+ data_type=[self.data_type, ],
+ sem_mask=sem_mask.int(),
+ stereo_depth= stereo_depth,
+ normal=normals[0],
+ inv_depth=inv_depth,
+ scale=transform_paras['label_scale_factor'])
+ return data
+
+ def get_data_for_test(self, idx: int):
+ anno = self.annotations['files'][idx]
+ meta_data = self.load_meta_data(anno)
+ data_path = self.load_data_path(meta_data)
+ data_batch = self.load_batch(meta_data, data_path)
+ # load data
+ curr_rgb, curr_depth, curr_normal, curr_cam_model = data_batch['curr_rgb'], data_batch['curr_depth'], data_batch['curr_normal'], data_batch['curr_cam_model']
+ ori_curr_intrinsic = meta_data['cam_in']
+
+ # get crop size
+ transform_paras = dict()
+ rgbs, depths, intrinsics, cam_models, _, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb,], #+ tmpl_rgbs,
+ labels=[curr_depth, ],
+ intrinsics=[ori_curr_intrinsic, ], # * (len(tmpl_rgbs) + 1),
+ cam_models=[curr_cam_model, ],
+ transform_paras=transform_paras)
+ # depth in original size and original metric
+ depth_out = self.clip_depth(curr_depth) * self.depth_range[1] # self.clip_depth(depths[0]) #
+ inv_depth = self.depth2invdepth(depth_out, np.zeros_like(depth_out, dtype=np.bool))
+ filename = os.path.basename(meta_data['rgb'])[:-4] + '.jpg'
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+ ori_curr_intrinsic_mat = self.intrinsics_list2mat(ori_curr_intrinsic)
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ scale_ratio = transform_paras['label_scale_factor'] if 'label_scale_factor' in transform_paras else 1.0
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+ raw_rgb = torch.from_numpy(curr_rgb)
+ curr_normal = torch.from_numpy(curr_normal.transpose((2,0,1)))
+
+
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ pad=pad,
+ scale=scale_ratio,
+ raw_rgb=raw_rgb,
+ sample_id=idx,
+ data_path=meta_data['rgb'],
+ inv_depth=inv_depth,
+ normal=curr_normal,
+ )
+ return data
+
+ def load_data_path(self, meta_data):
+ curr_rgb_path = os.path.join(self.data_root, meta_data['rgb'])
+ curr_depth_path = os.path.join(self.depth_root, meta_data['depth'])
+ curr_sem_path = os.path.join(self.sem_root, meta_data['sem']) \
+ if self.sem_root is not None and ('sem' in meta_data) and (meta_data['sem'] is not None) \
+ else None
+ # matterport3d separates xyz into three images
+ if ('normal' in meta_data) and (meta_data['normal'] is not None) and (self.norm_root is not None):
+ if isinstance(meta_data['normal'], dict):
+ curr_norm_path = {}
+ for k,v in meta_data['normal'].items():
+ curr_norm_path[k] = os.path.join(self.norm_root, v)
+ else:
+ curr_norm_path = os.path.join(self.norm_root, meta_data['normal'])
+ else:
+ curr_norm_path = None
+ curr_depth_mask_path = os.path.join(self.depth_mask_root, meta_data['depth_mask']) \
+ if self.depth_mask_root is not None and ('depth_mask' in meta_data) and (meta_data['depth_mask'] is not None) \
+ else None
+
+ if ('disp' in meta_data) and (meta_data['disp'] is not None) and (self.disp_root is not None):
+ if isinstance(meta_data['disp'], dict):
+ curr_disp_path = {}
+ for k,v in meta_data['disp'].items():
+ curr_disp_path[k] = os.path.join(self.disp_root, v)
+ else:
+ curr_disp_path = os.path.join(self.disp_root, meta_data['disp'])
+ else:
+ curr_disp_path = None
+
+ data_path=dict(
+ rgb_path=curr_rgb_path,
+ depth_path=curr_depth_path,
+ sem_path=curr_sem_path,
+ normal_path=curr_norm_path,
+ disp_path=curr_disp_path,
+ depth_mask_path=curr_depth_mask_path,
+ )
+ return data_path
+
+ def load_batch(self, meta_data, data_path):
+ curr_intrinsic = meta_data['cam_in']
+ # load rgb/depth
+ curr_rgb, curr_depth = self.load_rgb_depth(data_path['rgb_path'], data_path['depth_path'])
+ # get semantic labels
+ curr_sem = self.load_sem_label(data_path['sem_path'], curr_depth)
+ # create camera model
+ curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], curr_intrinsic)
+ # get normal labels
+ curr_normal = self.load_norm_label(data_path['normal_path'], H=curr_rgb.shape[0], W=curr_rgb.shape[1])
+ # get depth mask
+ depth_mask = self.load_depth_valid_mask(data_path['depth_mask_path'])
+ curr_depth[~depth_mask] = -1
+ # get stereo depth
+ curr_stereo_depth = self.load_stereo_depth_label(data_path['disp_path'], H=curr_rgb.shape[0], W=curr_rgb.shape[1])
+
+ data_batch = dict(
+ curr_rgb = curr_rgb,
+ curr_depth = curr_depth,
+ curr_sem = curr_sem,
+ curr_normal = curr_normal,
+ curr_cam_model=curr_cam_model,
+ curr_stereo_depth=curr_stereo_depth,
+ )
+ return data_batch
+
+
+ def clip_depth(self, depth: np.array) -> np.array:
+ depth[(depth>self.clip_depth_range[1]) | (depth<self.clip_depth_range[0])] = -1
+ return depth
+
+ def normalize_depth(self, depth: np.array) -> np.array:
+ depth /= self.depth_range[1]
+ depth[depth<self.EPS] = -1
+ return depth
+
+ def create_cam_model(self, H: int, W: int, intrinsics: list) -> np.array:
+ """
+ Encode the camera model (focal length and principal point) to a 4-channel map.
+ """
+ fx, fy, u0, v0 = intrinsics
+ f = (fx + fy) / 2.0
+ # principal point location
+ x_row = np.arange(0, W).astype(np.float32)
+ x_row_center_norm = (x_row - u0) / W
+ x_center = np.tile(x_row_center_norm, (H, 1)) # [H, W]
+
+ y_col = np.arange(0, H).astype(np.float32)
+ y_col_center_norm = (y_col - v0) / H
+ y_center = np.tile(y_col_center_norm, (W, 1)).T
+
+ # FoV
+ fov_x = np.arctan(x_center / (f / W))
+ fov_y = np.arctan(y_center/ (f / H))
+
+ cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2)
+ return cam_model
+
+ def check_data(self, data_dict : dict):
+ for k, v in data_dict.items():
+ if v is None:
+ # print(f'{self.data_name}, {k} cannot be read!')
+ self.logger.info(f'{self.data_name}, {k} cannot be read!')
+
+ def intrinsics_list2mat(self, intrinsics: torch.tensor) -> torch.tensor:
+ """
+ Create camera intrinsic matrix.
+ Args:
+ intrinsics (torch.tensor, [4,]): list of camera intrinsic parameters.
+ returns:
+ intrinsics_mat (torch.tensor, [3x3]): camera intrinsic parameters matrix.
+ """
+ intrinsics_mat = torch.zeros((3,3)).float()
+ intrinsics_mat[0, 0] = intrinsics[0]
+ intrinsics_mat[1, 1] = intrinsics[1]
+ intrinsics_mat[0, 2] = intrinsics[2]
+ intrinsics_mat[1, 2] = intrinsics[3]
+ intrinsics_mat[2, 2] = 1.0
+ return intrinsics_mat
+
+ # def load_tmpl_image(self, curr_rgb: np.array, meta_data: dict) -> dict:
+ # """
+ # Load consecutive RGB frames.
+ # Args:
+ # anno: the annotation for this group.
+ # curr_rgb: rgb image of the current frame.
+ # meta_data: meta data information.
+ # Returns:
+ # tmpl_annos: temporal rgbs.
+ # """
+ # w_tmpl = False
+
+ # tmpl_list = []
+ # # organize temporal annotations
+ # for i in self.tmpl_info:
+ # if (i in meta_data) and (meta_data[i] is not None) and os.path.exists(os.path.join(self.data_root, meta_data[i])):
+ # tmpl_list.append(os.path.join(self.data_root, meta_data[i]))
+
+ # if len(tmpl_list) == 0:
+ # rgb_tmpl = curr_rgb.copy()
+ # else:
+ # id = np.random.randint(len(tmpl_list))
+ # rgb_tmpl = self.load_data(tmpl_list[id], is_rgb_img=True)
+ # w_tmpl = True
+
+ # tmpl_annos = dict(
+ # tmpl_rgb_list = [rgb_tmpl,],
+ # w_tmpl = w_tmpl
+ # )
+ # return tmpl_annos
+
+ def load_meta_data(self, anno: dict) -> dict:
+ """
+ Load meta data information.
+ """
+ if self.meta_data_root is not None and ('meta_data' in anno or 'meta' in anno):
+ meta_data_path = os.path.join(self.meta_data_root, anno['meta_data']) if 'meta_data' in anno else os.path.join(self.meta_data_root, anno['meta'])
+ with open(meta_data_path, 'rb') as f:
+ meta_data = pickle.load(f)
+ meta_data.update(anno)
+ else:
+ meta_data = anno
+ return meta_data
+
+ def load_rgb_depth(self, rgb_path: str, depth_path: str):
+ """
+ Load the rgb and depth map with the paths.
+ """
+ rgb = self.load_data(rgb_path, is_rgb_img=True)
+ if rgb is None:
+ self.logger.info(f'>>>>{rgb_path} has errors.')
+
+ depth = self.load_data(depth_path)
+ if depth is None:
+ self.logger.info(f'{depth_path} has errors.')
+
+ # self.check_data(dict(
+ # rgb_path=rgb,
+ # depth_path=depth,
+ # ))
+ depth = depth.astype(np.float)
+ # if depth.shape != rgb.shape[:2]:
+ # print(f'no-equal in {self.data_name}')
+ # depth = cv2.resize(depth, rgb.shape[::-1][1:])
+
+ depth = self.process_depth(depth, rgb)
+ return rgb, depth
+
+ def load_sem_label(self, sem_path, depth=None, sky_id=142) -> np.array:
+ H, W = depth.shape
+ # if sem_path is not None:
+ # print(self.data_name)
+ sem_label = cv2.imread(sem_path, 0) if sem_path is not None \
+ else np.ones((H, W), dtype=np.int) * -1
+ if sem_label is None:
+ sem_label = np.ones((H, W), dtype=np.int) * -1
+ # set dtype to int before
+ sem_label = sem_label.astype(np.int)
+ sem_label[sem_label==255] = -1
+
+ # mask invalid sky region
+ mask_depth_valid = depth > 1e-8
+ invalid_sky_region = (sem_label==142) & (mask_depth_valid)
+ if self.data_type in ['lidar', 'sfm', 'denselidar', 'denselidar_nometric']:
+ sem_label[invalid_sky_region] = -1
+ return sem_label
+
+ def load_depth_valid_mask(self, depth_mask_path, depth=None) -> np.array:
+ if depth_mask_path is None:
+ return np.ones_like(depth, dtype=np.bool)
+ data_type = os.path.splitext(depth_mask_path)[-1]
+ if data_type in self.img_file_type:
+ data = cv2.imread(depth_mask_path, -1)
+ elif data_type in self.np_file_type:
+ data = np.load(depth_mask_path)
+ else:
+ raise RuntimeError(f'{data_type} is not supported in current version.')
+ data = data.astype(np.bool)
+ return data
+
+ def load_norm_label(self, norm_path, H, W):
+ norm_gt = np.zeros((H, W, 3)).astype(np.float32)
+ return norm_gt
+
+ def load_stereo_depth_label(self, disp_path, H, W):
+ stereo_depth_gt = np.zeros((H, W, 1)).astype(np.float32)
+ return stereo_depth_gt
+
+ def depth2invdepth(self, depth, sky_mask):
+ inv_depth = 1.0 / depth * self.disp_scale
+ inv_depth[depth<1e-6] = -1.0
+ inv_depth[inv_depth < 0] = -1.0
+ inv_depth[sky_mask] = 0
+ return inv_depth
+
+
+ def set_random_crop_size(self, random_crop_size):
+ self.random_crop_size[0] = random_crop_size[0]
+ self.random_crop_size[1] = random_crop_size[1]
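
For reference, load_annotations above simply json-loads the *_annotations_path file and indexes data_info['files'], so the expected on-disk layout follows the field names in the module docstring. A hypothetical minimal entry is sketched here; paths and intrinsic values are placeholders, and only the central-camera keys are shown:

    # Hypothetical minimal annotation structure for BaseDataset.load_annotations().
    # 'cam_in' is [fx, fy, u0, v0], matching create_cam_model / intrinsics_list2mat.
    minimal_annotations = {
        'files': [
            {
                'rgb': 'scene_0001/rgb/000000.jpg',      # relative to data_root
                'depth': 'scene_0001/depth/000000.png',  # relative to depth_root
                'cam_in': [721.5, 721.5, 609.6, 172.9],
                # optional keys ('sem', 'normal', 'depth_mask', 'meta_data', ...)
                # may be omitted; the loaders above fall back to -1 / zero labels.
            }
        ]
    }
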
diff --git a/training/mono/datasets/__init__.py b/training/mono/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3a5259334828e21090987a151d2ff83fc0d2fc3
--- /dev/null
+++ b/training/mono/datasets/__init__.py
@@ -0,0 +1,38 @@
+from .__base_dataset__ import BaseDataset
+from .ddad_dataset import DDADDataset
+from .mapillary_psd_dataset import MapillaryPSDDataset
+from .argovers2_dataset import Argovers2Dataset
+from .cityscapes_dataset import CityscapesDataset
+from .drivingstereo_dataset import DrivingStereoDataset
+from .dsec_dataset import DSECDataset
+from .lyft_dataset import LyftDataset
+from .diml_dataset import DIMLDataset
+from .any_dataset import AnyDataset
+from .nyu_dataset import NYUDataset
+from .scannet_dataset import ScanNetDataset
+from .diode_dataset import DIODEDataset
+from .kitti_dataset import KITTIDataset
+from .pandaset_dataset import PandasetDataset
+from .taskonomy_dataset import TaskonomyDataset
+from .uasol_dataset import UASOLDataset
+from .nuscenes_dataset import NuScenesDataset
+from .eth3d_dataset import ETH3DDataset
+from .waymo_dataset import WaymoDataset
+from .ibims_dataset import IBIMSDataset
+
+from .replica_dataset import ReplicaDataset
+from .hm3d_dataset import HM3DDataset
+from .matterport3d_dataset import Matterport3DDataset
+from .virtualkitti_dataset import VKITTIDataset
+from .blendedmvg_omni_dataset import BlendedMVGOmniDataset
+from .hypersim_dataset import HypersimDataset
+
+__all__ = ['BaseDataset', 'DDADDataset', 'MapillaryPSDDataset',
+'Argovers2Dataset', 'CityscapesDataset', 'DrivingStereoDataset', 'DSECDataset', 'LyftDataset', 'DIMLDataset', 'AnyDataset',
+'NYUDataset', 'ScanNetDataset', 'DIODEDataset', 'KITTIDataset', 'PandasetDataset', 'SUNRGBDDataset',
+'TaskonomyDataset',
+'UASOLDataset', 'NuScenesDataset',
+'G8V1Dataset', 'ETH3DDataset', 'WaymoDataset',
+'IBIMSDataset',
+'ReplicaDataset', 'HM3DDataset', 'Matterport3DDataset', 'VKITTIDataset',
+'BlendedMVGOmniDataset']
diff --git a/training/mono/datasets/any_dataset.py b/training/mono/datasets/any_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b3f82e48afb7fec3b4c2592df72bb24287de5f
--- /dev/null
+++ b/training/mono/datasets/any_dataset.py
@@ -0,0 +1,152 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+import copy
+from .__base_dataset__ import BaseDataset
+import mono.utils.transform as img_transform
+
+
+class AnyDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(AnyDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+
+ self.cfg = cfg
+ self.phase = phase
+ self.mldb_info = kwargs['mldb_info']
+
+ # root dir for data
+ self.data_root = os.path.join(self.mldb_info['mldb_root'], self.mldb_info['data_root'])
+ # depth/disp data root
+ disp_root = self.mldb_info['disp_root'] if 'disp_root' in self.mldb_info else None
+ self.disp_root = os.path.join(self.mldb_info['mldb_root'], disp_root) if disp_root is not None else None
+ depth_root = self.mldb_info['depth_root'] if 'depth_root' in self.mldb_info else None
+ self.depth_root = os.path.join(self.mldb_info['mldb_root'], depth_root) if depth_root is not None \
+ else self.data_root
+ # meta data root
+ meta_data_root = self.mldb_info['meta_data_root'] if 'meta_data_root' in self.mldb_info else None
+ self.meta_data_root = os.path.join(self.mldb_info['mldb_root'], meta_data_root) if meta_data_root is not None \
+ else None
+ # semantic segmentation labels root
+ sem_root = self.mldb_info['semantic_root'] if 'semantic_root' in self.mldb_info else None
+ self.sem_root = os.path.join(self.mldb_info['mldb_root'], sem_root) if sem_root is not None \
+ else None
+
+ # data annotations path
+ self.data_annos_path = '/yvan1/data/NuScenes/NuScenes/annotations/train_ring_annotations.json' # fill this
+
+ # load annotations
+ annotations = self.load_annotations()
+ whole_data_size = len(annotations['files'])
+
+ cfg_sample_ratio = cfg.data[phase].sample_ratio
+ cfg_sample_size = int(cfg.data[phase].sample_size)
+ self.sample_size = int(whole_data_size * cfg_sample_ratio) if cfg_sample_size == -1 \
+ else (cfg_sample_size if cfg_sample_size < whole_data_size else whole_data_size)
+ sample_list_of_whole_data = list(range(whole_data_size))[:self.sample_size]
+ self.data_size = self.sample_size
+ sample_list_of_whole_data = random.sample(list(range(whole_data_size)), whole_data_size)
+ self.annotations = {'files': [annotations['files'][i] for i in sample_list_of_whole_data]}
+ self.sample_list = list(range(self.data_size))
+
+ # config transforms for the input and label
+ self.transforms_cfg = cfg.data[phase]['pipeline']
+ self.transforms_lib = 'mono.utils.transform.'
+
+ self.img_file_type = ['.png', '.jpg', '.jpeg', '.bmp', '.tif']
+ self.np_file_type = ['.npz', '.npy']
+
+ # update canonical space information
+ self.data_basic = copy.deepcopy(kwargs)
+ canonical = self.data_basic.pop('canonical_space')
+ self.data_basic.update(canonical)
+ self.depth_range = kwargs['depth_range'] # predefined depth range for the network
+ self.clip_depth_range = kwargs['clip_depth_range'] # predefined depth range for data processing
+ self.depth_normalize = kwargs['depth_normalize']
+
+ self.img_transforms = img_transform.Compose(self.build_data_transforms())
+ self.EPS = 1e-8
+
+ self.tmpl_info = ['rgb_sr', 'rgb_pre', 'rgb_next']
+
+ # dataset info
+ self.data_name = cfg.data_name
+ self.data_type = cfg.data_type # there are mainly four types, i.e. ['rel', 'sfm', 'stereo', 'lidar']
+
+ def __getitem__(self, idx: int) -> dict:
+ return self.get_data_for_test(idx)
+
+ def get_data_for_test(self, idx: int):
+ # basic info
+ anno = self.annotations['files'][idx]
+ curr_rgb_path = os.path.join(self.data_root, anno['CAM_FRONT_RIGHT']['rgb']) # Lyft: CAM_FRONT_LEFT
+ curr_depth_path = os.path.join(self.depth_root, anno['CAM_FRONT_RIGHT']['depth'])
+ meta_data = self.load_meta_data(anno['CAM_FRONT_RIGHT'])
+ ori_curr_intrinsic = meta_data['cam_in']
+
+ curr_rgb, curr_depth = self.load_rgb_depth(curr_rgb_path, curr_depth_path)
+ ori_h, ori_w, _ = curr_rgb.shape
+ # create camera model
+ curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], ori_curr_intrinsic)
+ # load tmpl rgb info
+ # tmpl_annos = self.load_tmpl_annos(anno, curr_rgb, meta_data)
+ # tmpl_rgb = tmpl_annos['tmpl_rgb_list'] # list of reference rgbs
+
+ transform_paras = dict()
+ rgbs, depths, intrinsics, cam_models, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb, ],
+ labels=[curr_depth, ],
+ intrinsics=[ori_curr_intrinsic,],
+ cam_models=[curr_cam_model, ],
+ transform_paras=transform_paras)
+ # depth in augmented size
+ # depth_out = self.clip_depth(depths[0])
+ # depth in original size
+ #depth_out = self.clip_depth(curr_depth)
+ depth_out = curr_depth
+
+ filename = os.path.basename(curr_rgb_path)
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ scale_ratio = transform_paras['label_scale_factor'] if 'label_scale_factor' in transform_paras else 1.0
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+ raw_rgb = torch.from_numpy(curr_rgb)
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ # ref_input=rgbs[1:],
+ # tmpl_flg=tmpl_annos['w_tmpl'],
+ pad=pad,
+ scale=scale_ratio,
+ raw_rgb=raw_rgb)
+ return data
+
+
+ def process_depth(self, depth):
+ depth[depth>65500] = 0
+ depth /= 200.0
+ return depth
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+ dataset_i = AnyDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/argovers2_dataset.py b/training/mono/datasets/argovers2_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4963a07bb905a2d5df67ca95358bdcc8bbdd91be
--- /dev/null
+++ b/training/mono/datasets/argovers2_dataset.py
@@ -0,0 +1,33 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+import pickle
+
+class Argovers2Dataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(Argovers2Dataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+ def process_depth(self, depth, rgb):
+ depth[depth>65500] = 0
+ depth /= self.metric_scale
+ return depth
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+ dataset_i = Argovers2Dataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/blendedmvg_omni_dataset.py b/training/mono/datasets/blendedmvg_omni_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b96d7fd9865f8940c5ecc410485bcd88e0436e45
--- /dev/null
+++ b/training/mono/datasets/blendedmvg_omni_dataset.py
@@ -0,0 +1,32 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class BlendedMVGOmniDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(BlendedMVGOmniDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+ #self.cap_range = self.depth_range # in meter
+
+ # def __getitem__(self, idx: int) -> dict:
+ # if self.phase == 'test':
+ # return self.get_data_for_test(idx)
+ # else:
+ # return self.get_data_for_trainval(idx)
+
+
+ def process_depth(self, depth: np.array, rgb: np.array) -> np.array:
+ depth[depth>60000] = 0
+ depth = depth / self.metric_scale
+ return depth
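
The dataset modules in this diff all follow the same shape: subclass BaseDataset and override process_depth (plus load_meta_data or load_norm_label when a dataset needs special handling) to turn the stored depth encoding into metres. A hedged sketch of what a new wrapper would look like; the class name, sentinel value, and metric_scale comment below are illustrative, not part of this diff:

    from .__base_dataset__ import BaseDataset


    class MyLidarDataset(BaseDataset):
        """Hypothetical dataset wrapper following the pattern of the files in this diff."""
        def __init__(self, cfg, phase, **kwargs):
            super(MyLidarDataset, self).__init__(cfg=cfg, phase=phase, **kwargs)
            self.metric_scale = cfg.metric_scale   # stored-value-per-metre, set in the dataset config

        def process_depth(self, depth, rgb):
            depth[depth > 65500] = 0               # drop the uint16 'no reading' sentinel
            depth = depth / self.metric_scale      # convert stored values to metres
            return depth
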
diff --git a/training/mono/datasets/cityscapes_dataset.py b/training/mono/datasets/cityscapes_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..61d1bddfe85708ad49f968d35767b41990a131ca
--- /dev/null
+++ b/training/mono/datasets/cityscapes_dataset.py
@@ -0,0 +1,33 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class CityscapesDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(CityscapesDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+ def process_depth(self, depth, rgb):
+ depth[depth>65500] = 0
+ depth /= self.metric_scale
+ return depth
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+ dataset_i = CityscapesDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/ddad_dataset.py b/training/mono/datasets/ddad_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d913c034cb9d00aac295c68346e1a9ad3ad4117c
--- /dev/null
+++ b/training/mono/datasets/ddad_dataset.py
@@ -0,0 +1,37 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class DDADDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(DDADDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+ #self.cap_range = self.depth_range # in meter
+
+
+ def process_depth(self, depth, rgb):
+ depth[depth>65500] = 0
+ depth /= 200.0
+ # depth[(depth>self.cap_range[1]) | (depth<self.cap_range[0])] = -1
+ return depth
diff --git a/training/mono/datasets/diml_dataset.py b/training/mono/datasets/diml_dataset.py
new file mode 100644
--- /dev/null
+++ b/training/mono/datasets/diml_dataset.py
+import os
+import numpy as np
+import cv2
+import pickle
+from .__base_dataset__ import BaseDataset
+
+
+class DIMLDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(DIMLDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+ def load_meta_data(self, anno: dict) -> dict:
+ """
+ Load meta data information.
+ """
+ if self.meta_data_root is not None and ('meta_data' in anno or 'meta' in anno):
+ meta_data_path = os.path.join(self.meta_data_root, anno['meta_data']) if 'meta_data' in anno else os.path.join(self.meta_data_root, anno['meta'])
+ with open(meta_data_path, 'rb') as f:
+ meta_data = pickle.load(f)
+ meta_data.update(anno)
+ else:
+ meta_data = anno
+
+ # DIML_indoor has no cam_in
+ if 'cam_in' not in meta_data:
+ meta_data['cam_in'] = [1081, 1081, 704, 396]
+ return meta_data
+
+ def process_depth(self, depth, rgb):
+ depth[depth>65500] = 0
+ depth /= self.metric_scale
+ h, w, _ = rgb.shape # to rgb size
+ depth_resize = cv2.resize(depth, (w, h), interpolation=cv2.INTER_NEAREST)
+ return depth_resize
+
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+ dataset_i = DIMLDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/diode_dataset.py b/training/mono/datasets/diode_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d18541b4029474b3e99b5ee5dcde89d040a5e806
--- /dev/null
+++ b/training/mono/datasets/diode_dataset.py
@@ -0,0 +1,273 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+def creat_uv_mesh(H, W):
+ y, x = np.meshgrid(np.arange(0, H, dtype=np.float), np.arange(0, W, dtype=np.float), indexing='ij')
+ meshgrid = np.stack((x,y))
+ ones = np.ones((1,H*W), dtype=np.float)
+ xy = meshgrid.reshape(2, -1)
+ return np.concatenate([xy, ones], axis=0)
+
+class DIODEDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(DIODEDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+ # meshgrid for depth reprojection
+ self.xy = creat_uv_mesh(768, 1024)
+
+ def get_data_for_test(self, idx: int):
+ anno = self.annotations['files'][idx]
+ meta_data = self.load_meta_data(anno)
+ data_path = self.load_data_path(meta_data)
+ data_batch = self.load_batch(meta_data, data_path)
+ # load data
+ curr_rgb, curr_depth, curr_normal, curr_cam_model = data_batch['curr_rgb'], data_batch['curr_depth'], data_batch['curr_normal'], data_batch['curr_cam_model']
+ ori_curr_intrinsic = meta_data['cam_in']
+
+ # get crop size
+ transform_paras = dict()
+ rgbs, depths, intrinsics, cam_models, _, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb,], #+ tmpl_rgbs,
+ labels=[curr_depth, ],
+ intrinsics=[ori_curr_intrinsic, ], # * (len(tmpl_rgbs) + 1),
+ cam_models=[curr_cam_model, ],
+ transform_paras=transform_paras)
+ # depth in original size and original metric
+ depth_out = self.clip_depth(curr_depth) * self.depth_range[1] # self.clip_depth(depths[0]) #
+ inv_depth = self.depth2invdepth(depth_out, np.zeros_like(depth_out, dtype=np.bool))
+ filename = os.path.basename(meta_data['rgb'])[:-4] + '.jpg'
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+ ori_curr_intrinsic_mat = self.intrinsics_list2mat(ori_curr_intrinsic)
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ scale_ratio = transform_paras['label_scale_factor'] if 'label_scale_factor' in transform_paras else 1.0
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+ raw_rgb = torch.from_numpy(curr_rgb)
+ curr_normal = torch.from_numpy(curr_normal.transpose((2,0,1)))
+
+
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ pad=pad,
+ scale=scale_ratio,
+ raw_rgb=raw_rgb,
+ sample_id=idx,
+ data_path=meta_data['rgb'],
+ inv_depth=inv_depth,
+ normal=curr_normal,
+ )
+ return data
+
+
+ # def get_data_for_trainval(self, idx: int):
+ # anno = self.annotations['files'][idx]
+ # meta_data = self.load_meta_data(anno)
+
+ # # curr_rgb_path = os.path.join(self.data_root, meta_data['rgb'])
+ # # curr_depth_path = os.path.join(self.depth_root, meta_data['depth'])
+ # # curr_sem_path = os.path.join(self.sem_root, meta_data['sem']) if self.sem_root is not None and ('sem' in meta_data) and (meta_data['sem'] is not None) else None
+ # # curr_depth_mask_path = os.path.join(self.depth_mask_root, meta_data['depth_mask']) if self.depth_mask_root is not None and ('depth_mask' in meta_data) and (meta_data['depth_mask'] is not None) else None
+ # data_path = self.load_data_path(meta_data)
+ # data_batch = self.load_batch(meta_data, data_path)
+
+ # curr_rgb, curr_depth, curr_normal, curr_sem, curr_cam_model = data_batch['curr_rgb'], data_batch['curr_depth'], data_batch['curr_normal'], data_batch['curr_sem'], data_batch['curr_cam_model']
+
+ # # load data
+ # # curr_intrinsic = meta_data['cam_in']
+ # # curr_rgb, curr_depth = self.load_rgb_depth(curr_rgb_path, curr_depth_path)
+
+ # # # mask the depth
+ # # curr_depth = curr_depth.squeeze()
+ # # depth_mask = self.load_depth_valid_mask(curr_depth_mask_path, curr_depth)
+ # # curr_depth[~depth_mask] = -1
+
+
+ # # # get semantic labels
+ # # curr_sem = self.load_sem_label(curr_sem_path, curr_depth)
+ # # # create camera model
+ # # curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], curr_intrinsic)
+
+ # # get crop size
+ # transform_paras = dict(random_crop_size = self.random_crop_size)
+ # rgbs, depths, intrinsics, cam_models, _, other_labels, transform_paras = self.img_transforms(
+ # images=[curr_rgb, ],
+ # labels=[curr_depth, ],
+ # intrinsics=[curr_intrinsic,],
+ # cam_models=[curr_cam_model, ],
+ # other_labels=[curr_sem, ],
+ # transform_paras=transform_paras)
+ # # process sky masks
+ # sem_mask = other_labels[0].int()
+
+ # # clip depth map
+ # depth_out = self.normalize_depth(depths[0])
+ # # set the depth in sky region to the maximum depth
+ # depth_out[sem_mask==142] = -1 #self.depth_normalize[1] - 1e-6
+ # filename = os.path.basename(meta_data['rgb'])
+ # curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+ # cam_models_stacks = [
+ # torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ # for i in [2, 4, 8, 16, 32]
+ # ]
+ # pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ # data = dict(input=rgbs[0],
+ # target=depth_out,
+ # intrinsic=curr_intrinsic_mat,
+ # filename=filename,
+ # dataset=self.data_name,
+ # cam_model=cam_models_stacks,
+ # #ref_input=rgbs[1:],
+ # # tmpl_flg=tmpl_annos['w_tmpl'],
+ # pad=torch.tensor(pad),
+ # data_type=[self.data_type, ],
+ # sem_mask=sem_mask.int())
+ # return data
+
+ # def get_data_for_test(self, idx: int):
+ # anno = self.annotations['files'][idx]
+ # meta_data = self.load_meta_data(anno)
+ # curr_rgb_path = os.path.join(self.data_root, meta_data['rgb'])
+ # curr_depth_path = os.path.join(self.depth_root, meta_data['depth'])
+ # curr_depth_mask_path = os.path.join(self.depth_mask_root, meta_data['depth_mask']) if self.depth_mask_root is not None and ('depth_mask' in meta_data) and (meta_data['depth_mask'] is not None) else None
+
+ # # load data
+ # ori_curr_intrinsic = meta_data['cam_in']
+ # curr_rgb, curr_depth = self.load_rgb_depth(curr_rgb_path, curr_depth_path)
+
+ # # mask the depth
+ # curr_depth = curr_depth.squeeze()
+ # depth_mask = self.load_depth_valid_mask(curr_depth_mask_path, curr_depth)
+ # curr_depth[~depth_mask] = -1
+
+ # ori_h, ori_w, _ = curr_rgb.shape
+ # # create camera model
+ # curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], ori_curr_intrinsic)
+
+ # # get crop size
+ # transform_paras = dict()
+ # rgbs, depths, intrinsics, cam_models, _, other_labels, transform_paras = self.img_transforms(
+ # images=[curr_rgb,], #+ tmpl_rgbs,
+ # labels=[curr_depth, ],
+ # intrinsics=[ori_curr_intrinsic, ], # * (len(tmpl_rgbs) + 1),
+ # cam_models=[curr_cam_model, ],
+ # transform_paras=transform_paras)
+ # # depth in original size and orignial metric***
+ # depth_out = self.clip_depth(curr_depth) * self.depth_range[1] # self.clip_depth(depths[0]) #
+
+ # filename = os.path.basename(meta_data['rgb'])
+ # curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+
+ # pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ # scale_ratio = transform_paras['label_scale_factor'] if 'label_scale_factor' in transform_paras else 1.0
+ # cam_models_stacks = [
+ # torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ # for i in [2, 4, 8, 16, 32]
+ # ]
+ # raw_rgb = torch.from_numpy(curr_rgb)
+ # # rel_pose = torch.from_numpy(tmpl_annos['tmpl_pose_list'][0])
+
+ # data = dict(input=rgbs[0],
+ # target=depth_out,
+ # intrinsic=curr_intrinsic_mat,
+ # filename=filename,
+ # dataset=self.data_name,
+ # cam_model=cam_models_stacks,
+ # pad=pad,
+ # scale=scale_ratio,
+ # raw_rgb=raw_rgb,
+ # sample_id=idx,
+ # data_path=meta_data['rgb'],
+ # )
+ # return data
+
+
+ def load_batch(self, meta_data, data_path):
+ curr_intrinsic = meta_data['cam_in']
+ # load rgb/depth
+ curr_rgb, curr_depth = self.load_rgb_depth(data_path['rgb_path'], data_path['depth_path'])
+ # get semantic labels
+ curr_sem = self.load_sem_label(data_path['sem_path'], curr_depth)
+ # create camera model
+ curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], curr_intrinsic)
+ # get normal labels
+
+ try:
+ curr_normal = self.load_norm_label(data_path['normal_path'], H=curr_rgb.shape[0], W=curr_rgb.shape[1], depth=curr_depth, K=curr_intrinsic) # note: unlike BaseDataset, DIODE normals also need depth and intrinsics for orientation alignment
+ except:
+ curr_normal = np.zeros_like(curr_rgb)
+ # get depth mask
+ depth_mask = self.load_depth_valid_mask(data_path['depth_mask_path'])
+ curr_depth[~depth_mask] = -1
+ data_batch = dict(
+ curr_rgb = curr_rgb,
+ curr_depth = curr_depth,
+ curr_sem = curr_sem,
+ curr_normal = curr_normal,
+ curr_cam_model=curr_cam_model,
+ )
+ return data_batch
+
+
+ def load_norm_label(self, norm_path, H, W, depth, K):
+ normal = np.load(norm_path)
+ normal[:,:,1:] *= -1
+ normal = self.align_normal(normal, depth, K, H, W)
+
+ return normal
+
+
+ def process_depth(self, depth, rgb):
+ depth[depth>150] = 0
+ depth[depth<0.1] = 0
+ depth /= self.metric_scale
+ return depth
+
+ def align_normal(self, normal, depth, K, H, W):
+ # inv K
+ K = np.array([[K[0], 0 ,K[2]],
+ [0, K[1], K[3]],
+ [0, 0, 1]])
+ inv_K = np.linalg.inv(K)
+ # reprojection depth to camera points
+ if H == 768 and W == 1024:
+ xy = self.xy
+ else:
+ print('img size no-equal 768x1024')
+ xy = creat_uv_mesh(H, W)
+ points = np.matmul(inv_K[:3, :3], xy).reshape(3, H, W)
+ points = depth * points
+ points = points.transpose((1,2,0))
+
+ # align normal
+ orient_mask = np.sum(normal * points, axis=2) > 0
+ normal[orient_mask] *= -1
+
+ return normal
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+ dataset_i = DIODEDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/distributed_sampler.py b/training/mono/datasets/distributed_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..da3639964f4391b7b5b308bec64ca1be1f2d45e4
--- /dev/null
+++ b/training/mono/datasets/distributed_sampler.py
@@ -0,0 +1,275 @@
+import numpy as np
+import logging
+import torch.distributed as dist
+import math
+import os
+from mono.utils.comm import get_func, main_process
+from torch.utils.data import ConcatDataset, DataLoader
+import random
+import copy
+import torch
+import logging
+
+
+def build_dataset_n_sampler_with_cfg(cfg, phase):
+ # build data array, similar datasets are organized in the same group
+ datasets_array = build_data_array(cfg, phase)
+ # concatenate datasets with torch.utils.data.ConcatDataset methods
+ dataset_merge = concatenate_datasets(datasets_array)
+ # customized sampler
+ custom_sampler = CustomerMultiDataSampler(cfg, dataset_merge, phase)
+ return dataset_merge, custom_sampler
+
+class CustomerMultiDataSampler(torch.utils.data.Sampler):
+ """
+ Customized sampler. During sampling, the size of some datasets is tailored or expanded
+ so that every group ends up with the same data size.
+ e.g. with dataset_list [[A, B, C], [E, F], M], group 'A,B,C' (Size(A) + Size(B) + Size(C)) is balanced
+ to the same size as group 'E,F' (Size(E) + Size(F)), and likewise for 'M'.
+ args:
+ @ cfg: configs for each dataset.
+ @ dataset_merge: merged multiple datasets with the torch.utils.data.ConcatDataset method.
+ @ phase: train/val/test phase.
+ """
+
+ def __init__(self, cfg, dataset_merge, phase):
+ self.cfg = cfg
+ self.world_size = int(os.environ['WORLD_SIZE'])
+ self.phase = phase
+ self.global_rank = cfg.dist_params.global_rank
+ self.dataset_merge = dataset_merge
+ self.logger = logging.getLogger()
+ if main_process():
+ self.logger.info(f'Initialized CustomerMultiDataSampler for {phase}.')
+ self.random_seed = 136
+ self.random_seed_cp = 639
+
+ def __iter__(self):
+ self.create_samplers()
+ self.logger.info("Sample list of {} in rank {} is: {}".format(self.phase, self.global_rank, ' '.join(map(str, self.sample_indices_array[-20: -10]))))
+ # subsample, each rank sample a subset for training.
+ rank_offset = self.each_gpu_size * self.global_rank
+ rank_indices = self.sample_indices_array[rank_offset : rank_offset + self.each_gpu_size]
+
+ assert rank_indices.size == self.each_gpu_size
+
+ for id in rank_indices:
+ yield id
+
+ def __len__(self):
+ return self.total_dist_size
+
+ def create_samplers(self):
+ # sample idx for each dataset, idx value should not exceed the size of data,
+ # i.e. 0 <= idx < len(data_size)
+ #self.samples_mat = []
+ self.indices_mat = []
+ # size expanded, indices cumulatively aggregated for indexing
+ self.indices_expand_mat = []
+
+ # max group size, each group may consists of multiple similar datasets
+ max_group_size = max([len(i) for i in self.dataset_merge.datasets])
+
+ dataset_cumulative_sizes = [0] + self.dataset_merge.cumulative_sizes
+
+ for gi, dataset_group in enumerate(self.dataset_merge.datasets):
+ # the merged dataset consists of multiple grouped datasets
+ samples_group = []
+ indices_expand_group = []
+ indices_group = []
+
+            # to ensure each group has the same size, a group with fewer samples duplicates its sample list 'cp_times' times
+ cp_times = max_group_size / len(dataset_group)
+
+ # adjust each group to ensure they have the same data size
+ group_cumulative_sizes = [0] + dataset_group.cumulative_sizes
+            expand_indices_sizes = (np.array(group_cumulative_sizes) * cp_times).astype(int)  # np.int is removed in recent NumPy
+ expand_indices_sizes[-1] = max_group_size
+            # each dataset in the group expands its own sample list accordingly
+ expand_indices_sizes = expand_indices_sizes[1:] - expand_indices_sizes[:-1]
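+            # e.g. group sizes [100, 50] with max_group_size 300 (cp_times = 2):
+            # cumulative [0, 100, 150] -> scaled [0, 200, 300] -> per-dataset expanded sizes [200, 100]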
+
+ for di, dataset_i in enumerate(dataset_group.datasets):
+ # datasets residing in each group may have similar features
+ # samples indices list
+ dataset_i_ori_sample_list = self.dataset_merge.datasets[gi].datasets[di].sample_list
+ if self.phase == 'train':
+ #sample_list_i = random.sample(dataset_i_ori_sample_list, len(dataset_i_ori_sample_list))
+ sample_list_i = dataset_i_ori_sample_list
+ else:
+ # no shuffle in val or test
+ sample_list_i = dataset_i_ori_sample_list
+ #samples_group.append(sample_list_i)
+
+ # expand the sample list for each dataset
+ expand_size_i = expand_indices_sizes[di]
+ indices_expand_list = copy.deepcopy(sample_list_i)
+
+ for i in range(int(cp_times)-1):
+ #indices_expand_list += random.sample(sample_list_i, len(dataset_i))
+ indices_expand_list += sample_list_i
+ random.seed(self.random_seed_cp)
+ indices_expand_list += random.sample(sample_list_i, len(dataset_i))[:expand_size_i % len(dataset_i)]
+ # adjust indices value
+ indices_expand_list = np.array(indices_expand_list) + dataset_cumulative_sizes[gi] + group_cumulative_sizes[di]
+ indices_list = np.array(sample_list_i) + dataset_cumulative_sizes[gi] + group_cumulative_sizes[di]
+
+ # the expanded sample list for dataset_i
+ indices_expand_group.append(indices_expand_list)
+ # the original sample list for the dataset_i
+ indices_group.append(indices_list)
+
+ if main_process():
+ self.logger.info(f'"{dataset_i.data_name}", {self.phase} set in group {gi}: ' +
+ f'expand size {len(sample_list_i)} --->>>---, {expand_size_i}')
+
+ concat_group = np.concatenate(indices_expand_group)
+            # shuffle the samples of the grouped datasets, e.g. [a1, a2, a3, b1, b2, b3, b4, c1, c2] may become [a3, b1, c2, a1, b4, ...]
+ np.random.seed(self.random_seed)
+ if self.phase == 'train':
+ np.random.shuffle(concat_group)
+ self.indices_expand_mat.append(concat_group)
+ self.indices_mat.append(np.concatenate(indices_group))
+
+ # create sample list
+ if "train" in self.phase:
+ # data groups are cross sorted, i.e. [A, B, C, A, B, C....]
+ self.sample_indices_array = np.array(self.indices_expand_mat).transpose(1, 0).reshape(-1)
+ self.total_indices_size = max_group_size * len(self.dataset_merge.datasets)
+ else:
+ self.sample_indices_array = np.concatenate(self.indices_mat[:])
+ self.total_indices_size = self.sample_indices_array.size
+
+ self.total_sample_size = len(self.dataset_merge)
+        self.each_gpu_size = int(np.ceil(self.total_indices_size * 1.0 / self.world_size)) # round up so every rank gets the same number of samples; padded below
+ self.total_dist_size = self.each_gpu_size * self.world_size
+ # add extra samples to make it evenly divisible
+ diff_size = int(self.total_dist_size - self.total_indices_size) # int(self.total_dist_size - self.total_sample_size)
+ if diff_size > 0:
+ self.sample_indices_array = np.append(self.sample_indices_array, self.sample_indices_array[:diff_size])
+ #if main_process():
+        self.logger.info(f'Merged dataset size: {self.total_sample_size}, adjusted data size for distributed running: {self.total_dist_size}')
+ self.random_seed += 413
+ self.random_seed_cp += 377
+
+
+def build_data_array(cfg, phase):
+ """
+    Construct the dataset array from cfg. cfg contains a data name array that lists each dataset.
+    Each data name links to a dataset config file, from which the dataset is constructed.
+ e.g. [['A', 'B', 'C'], ['E', 'F'], 'M']. Each letter indicates a dataset.
+ """
+
+ datasets_array = []
+ data_array_names_for_log = []
+
+ dataname_array = cfg.data_array
+ for group_i in dataname_array:
+ dataset_group_i = []
+ data_group_i_names_for_log = []
+ if not isinstance(group_i, list):
+ group_i = [group_i, ]
+ for data_i in group_i:
+ if not isinstance(data_i, dict):
+ raise TypeError(f'data name must be a dict, but got {type(data_i)}')
+            # each data entry may reference only a single dataset config
+ assert len(data_i.values()) == 1
+ if list(data_i.values())[0] not in cfg:
+ raise RuntimeError(f'cannot find the data config for {data_i}')
+
+ # dataset configure for data i
+ #data_i_cfg = cfg[data_i]
+ args = copy.deepcopy(cfg) #data_i_cfg.copy()
+ data_i_cfg_name = list(data_i.values())[0]
+ data_i_db_info_name = list(data_i.keys())[0]
+ data_i_db_info = cfg.db_info[data_i_db_info_name]
+
+ # Online evaluation using only metric datasets
+ # if phase == 'val' and 'exclude' in cfg.evaluation \
+ # and data_i_db_info_name in cfg.evaluation.exclude:
+ # continue
+
+ # dataset lib name
+ obj_name = cfg[data_i_cfg_name]['lib']
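+            # build a dotted import path for the dataset class, e.g. 'mono.datasets.KITTIDataset'
+            # (the exact prefix depends on the working directory the training script is launched from)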
+ obj_path = os.path.dirname(__file__).split(os.getcwd() + '/')[-1].replace('/', '.') + '.' + obj_name
+ obj_cls = get_func(obj_path)
+ if obj_cls is None:
+ raise KeyError(f'{obj_name} is not in .data')
+
+ dataset_i = obj_cls(
+ args[data_i_cfg_name],
+ phase,
+ db_info=data_i_db_info,
+ **cfg.data_basic)
+ # if 'Taskonomy' not in data_i:
+ # print('>>>>>>>>>>ditributed_sampler LN189', dataset_i.data_name, dataset_i.annotations['files'][0]['rgb'].split('/')[-1],
+ # dataset_i.annotations['files'][1000]['rgb'].split('/')[-1], dataset_i.annotations['files'][3000]['rgb'].split('/')[-1])
+ # else:
+ # print('>>>>>>>>>>ditributed_sampler LN189', dataset_i.data_name, dataset_i.annotations['files'][0]['meta_data'].split('/')[-1],
+ # dataset_i.annotations['files'][1000]['meta_data'].split('/')[-1], dataset_i.annotations['files'][3000]['meta_data'].split('/')[-1])
+ dataset_group_i.append(dataset_i)
+ # get data name for log
+ data_group_i_names_for_log.append(data_i_db_info_name)
+
+ datasets_array.append(dataset_group_i)
+ data_array_names_for_log.append(data_group_i_names_for_log)
+
+ if main_process():
+ logger = logging.getLogger()
+ logger.info(f'{phase}: data array ({data_array_names_for_log}) has been constructed.')
+ return datasets_array
+
+def concatenate_datasets(datasets_array):
+ """
+ Merge grouped datasets to a single one.
+ args:
+    @ datasets_array: the array of constructed datasets.
+ """
+ #max_size = 0
+ dataset_merge = []
+ for group in datasets_array:
+ group_dataset = ConcatDataset(group)
+ group_size = len(group_dataset)
+ #max_size = max_size if group_size < max_size else group_size
+ dataset_merge.append(group_dataset)
+ return ConcatDataset(dataset_merge)
+
+
+def log_canonical_transfer_info(cfg):
+ logger = logging.getLogger()
+ data = []
+ canonical_focal_length = cfg.data_basic.canonical_space.focal_length
+ canonical_size = cfg.data_basic.canonical_space.img_size
+ for group_i in cfg.data_array:
+ if not isinstance(group_i, list):
+ group_i = [group_i, ]
+ for data_i in group_i:
+ if not isinstance(data_i, dict):
+ raise TypeError(f'data name must be a dict, but got {type(data_i)}')
+ assert len(data_i.values()) == 1
+ if list(data_i.values())[0] not in cfg:
+ raise RuntimeError(f'cannot find the data config for {data_i.values()}')
+ if list(data_i.values())[0] not in data:
+ data.append(list(data_i.values())[0])
+
+ logger.info('>>>>>>>>>>>>>>Some data transfer details during augmentation.>>>>>>>>>>>>>>')
+ for data_i in data:
+ data_i_cfg = cfg[data_i]
+        if not isinstance(data_i_cfg.original_focal_length, tuple):
+ ori_focal = (data_i_cfg.original_focal_length, )
+ else:
+ ori_focal = data_i_cfg.original_focal_length
+
+ log_str = '%s transfer details: \n' % data_i
+ for ori_f in ori_focal:
+ # to canonical space
+            scale_factor = canonical_focal_length / ori_f
+            img_size = (data_i_cfg.original_size[0]*scale_factor, data_i_cfg.original_size[1]*scale_factor)
+ log_str += 'To canonical space: focal length, %f -> %f; size, %s -> %s\n' %(ori_f, canonical_focal_length, data_i_cfg.original_size, img_size)
+
+            # random resize in augmentation
+ resize_range = data_i_cfg.data.train.pipeline[1].ratio_range
+ resize_low = (img_size[0]*resize_range[0], img_size[1]*resize_range[0])
+ resize_up = (img_size[0]*resize_range[1], img_size[1]*resize_range[1])
+ log_str += 'Random resize bound: %s ~ %s; \n' %(resize_low, resize_up)
+
+ logger.info(log_str)
\ No newline at end of file
diff --git a/training/mono/datasets/drivingstereo_dataset.py b/training/mono/datasets/drivingstereo_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce6aa79bb3054a4909e87fe010ef22fe01736b71
--- /dev/null
+++ b/training/mono/datasets/drivingstereo_dataset.py
@@ -0,0 +1,35 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class DrivingStereoDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(DrivingStereoDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+
+
+ def process_depth(self, depth, rgb):
+ depth[depth>65500] = 0
+ depth /= self.metric_scale
+ return depth
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+    dataset_i = DrivingStereoDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/dsec_dataset.py b/training/mono/datasets/dsec_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1029c71c69c2af99a0c9c119332a4d7ee29dd366
--- /dev/null
+++ b/training/mono/datasets/dsec_dataset.py
@@ -0,0 +1,35 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class DSECDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(DSECDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+
+
+ def process_depth(self, depth, rgb):
+ depth[depth>65500] = 0
+ depth /= self.metric_scale
+ return depth
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+    dataset_i = DSECDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/eth3d_dataset.py b/training/mono/datasets/eth3d_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..380e6fd6138ea05841efd886cbafd75d0f37adb7
--- /dev/null
+++ b/training/mono/datasets/eth3d_dataset.py
@@ -0,0 +1,94 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class ETH3DDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(ETH3DDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+ def __getitem__(self, idx):
+ anno = self.annotations['files'][idx]
+ curr_rgb_path = os.path.join(self.data_root, anno['rgb_path'])
+ curr_depth_path = os.path.join(self.depth_root, anno['depth_path'])
+ meta_data = self.load_meta_data(anno)
+ ori_curr_intrinsic = [2000, 2000, 3024, 2016] #meta_data['cam_in']
+
+        curr_rgb = cv2.imread(curr_rgb_path) # BGR order, as returned by cv2.imread
+        with open(curr_depth_path, 'rb') as f: # binary read for np.fromfile
+            imgfile = np.fromfile(f, np.float32)
+ curr_depth = imgfile.reshape((4032, 6048))
+ curr_depth[curr_depth>100] = 0
+
+ #curr_rgb, curr_depth = self.load_rgb_depth(curr_rgb_path, curr_depth_path)
+ # curr_rgb = cv2.resize(curr_rgb, dsize=(3024, 2016), interpolation=cv2.INTER_LINEAR)
+ # curr_depth = cv2.resize(curr_depth, dsize=(3024, 2016), interpolation=cv2.INTER_LINEAR)
+ # ori_curr_intrinsic = [i//2 for i in ori_curr_intrinsic]
+
+ ori_h, ori_w, _ = curr_rgb.shape
+ # create camera model
+ curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], ori_curr_intrinsic)
+ # load tmpl rgb info
+ # tmpl_annos = self.load_tmpl_annos(anno, curr_rgb, meta_data)
+ # tmpl_rgb = tmpl_annos['tmpl_rgb_list'] # list of reference rgbs
+
+ transform_paras = dict()
+ rgbs, depths, intrinsics, cam_models, _, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb, ],
+ labels=[curr_depth, ],
+ intrinsics=[ori_curr_intrinsic,],
+ cam_models=[curr_cam_model, ],
+ transform_paras=transform_paras)
+ # depth in original size
+ depth_out = self.clip_depth(curr_depth) * self.depth_range[1]
+
+ filename = os.path.basename(anno['rgb_path'])
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ scale_ratio = transform_paras['label_scale_factor'] if 'label_scale_factor' in transform_paras else 1.0
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+ raw_rgb = torch.from_numpy(curr_rgb)
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ ref_input=rgbs[1:],
+ tmpl_flg=False,
+ pad=pad,
+ scale=scale_ratio,
+ raw_rgb=raw_rgb,
+ normal = np.zeros_like(curr_rgb.transpose((2,0,1))),
+ #stereo_depth=torch.zeros_like(depth_out)
+ )
+ return data
+
+    def process_depth(self, depth, rgb=None):
+ depth[depth>65500] = 0
+ depth /= self.metric_scale
+ return depth
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+    dataset_i = ETH3DDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/fisheye_dataset.py b/training/mono/datasets/fisheye_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9c2d75851451ea3a6dbd5e5c79cc44a80fe7402
--- /dev/null
+++ b/training/mono/datasets/fisheye_dataset.py
@@ -0,0 +1,76 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class FisheyeDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(FisheyeDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+ def load_data(self, path: str, is_rgb_img: bool=False):
+ if not os.path.exists(path):
+ self.logger.info(f'>>>>{path} does not exist.')
+ # raise RuntimeError(f'{path} does not exist.')
+
+ data_type = os.path.splitext(path)[-1]
+ if data_type in self.img_file_type:
+ if is_rgb_img:
+ data = cv2.imread(path)
+ else:
+ data = cv2.imread(path, -1)
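+                # depth maps: values above 65500 are treated as invalid, and the highest
+                # bit is assumed to carry a flag rather than depth, so it is cleared below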
+ data[data>65500] = 0
+ data &= 0x7FFF
+
+ elif data_type in self.np_file_type:
+ data = np.load(path)
+ else:
+ raise RuntimeError(f'{data_type} is not supported in current version.')
+
+ return data.squeeze()
+
+ def load_batch(self, meta_data, data_path):
+ curr_intrinsic = meta_data['cam_in']
+ # load rgb/depth
+ curr_rgb, curr_depth = self.load_rgb_depth(data_path['rgb_path'], data_path['depth_path'])
+ # get semantic labels
+ curr_sem = self.load_sem_label(data_path['sem_path'], curr_depth)
+ # create camera model
+ curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], curr_intrinsic)
+ # get normal labels
+ curr_normal = self.load_norm_label(data_path['normal_path'], H=curr_rgb.shape[0], W=curr_rgb.shape[1])
+ # get depth mask
+ depth_mask = self.load_depth_valid_mask(data_path['depth_mask_path'])[:, :, :]
+
+        # apply the externally provided validity mask: invalidate masked depth and zero out masked rgb
+ curr_depth[~(depth_mask[:, :, 0])] = -1
+ curr_rgb[~(depth_mask[:, :, :])] = 0
+
+ # get stereo depth
+ curr_stereo_depth = self.load_stereo_depth_label(data_path['disp_path'], H=curr_rgb.shape[0], W=curr_rgb.shape[1])
+
+ data_batch = dict(
+ curr_rgb = curr_rgb,
+ curr_depth = curr_depth,
+ curr_sem = curr_sem,
+ curr_normal = curr_normal,
+ curr_cam_model=curr_cam_model,
+ curr_stereo_depth=curr_stereo_depth,
+ )
+ return data_batch
+
+
+ def process_depth(self, depth, rgb):
+
+ depth /= self.metric_scale
+ return depth
diff --git a/training/mono/datasets/hm3d_dataset.py b/training/mono/datasets/hm3d_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d143453c9e16f19bfe778a1d358207e9bd2b8d57
--- /dev/null
+++ b/training/mono/datasets/hm3d_dataset.py
@@ -0,0 +1,35 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from PIL import Image
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class HM3DDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(HM3DDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+ #self.cap_range = self.depth_range # in meter
+
+ def load_norm_label(self, norm_path, H, W):
+ with open(norm_path, 'rb') as f:
+ normal = Image.open(f)
+ normal = np.array(normal.convert(normal.mode), dtype=np.uint8)
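+            # pixels equal to (128, 128, 128) are treated as invalid; the rest are mapped from [0, 255] to [-1, 1]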
+ invalid_mask = np.all(normal == 128, axis=2)
+ normal = normal.astype(np.float64) / 255.0 * 2 - 1
+ normal[invalid_mask, :] = 0
+ return normal
+
+ def process_depth(self, depth: np.array, rgb: np.array) -> np.array:
+ depth[depth>60000] = 0
+ depth = depth / self.metric_scale
+ return depth
diff --git a/training/mono/datasets/hypersim_dataset.py b/training/mono/datasets/hypersim_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d255fceb11f7e93edf910431f73942367ce0642c
--- /dev/null
+++ b/training/mono/datasets/hypersim_dataset.py
@@ -0,0 +1,141 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from PIL import Image
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+import h5py
+
+def creat_uv_mesh(H, W):
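+    # homogeneous pixel coordinates of shape (3, H*W): rows are x, y and 1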
+    y, x = np.meshgrid(np.arange(0, H, dtype=np.float64), np.arange(0, W, dtype=np.float64), indexing='ij')
+    meshgrid = np.stack((x,y))
+    ones = np.ones((1,H*W), dtype=np.float64)
+ xy = meshgrid.reshape(2, -1)
+ return np.concatenate([xy, ones], axis=0)
+
+class HypersimDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(HypersimDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+ #self.cap_range = self.depth_range # in meter
+ # init uv
+
+ # meshgrid for depth reprojection
+ self.xy = creat_uv_mesh(768, 1024)
+
+ def load_batch(self, meta_data, data_path):
+ curr_intrinsic = meta_data['cam_in']
+ # load rgb/depth
+ curr_rgb, curr_depth = self.load_rgb_depth(data_path['rgb_path'], data_path['depth_path'])
+ # get semantic labels
+ curr_sem = self.load_sem_label(data_path['sem_path'], curr_depth)
+ # create camera model
+ curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], curr_intrinsic)
+ # get normal labels
+        curr_normal = self.load_norm_label(data_path['normal_path'], H=curr_rgb.shape[0], W=curr_rgb.shape[1], depth=curr_depth, K=curr_intrinsic) # note: differs from BaseDataset
+ # get depth mask
+ depth_mask = self.load_depth_valid_mask(data_path['depth_mask_path'])
+ curr_depth[~depth_mask] = -1
+ data_batch = dict(
+ curr_rgb = curr_rgb,
+ curr_depth = curr_depth,
+ curr_sem = curr_sem,
+ curr_normal = curr_normal,
+ curr_cam_model=curr_cam_model,
+ )
+ return data_batch
+
+ def load_data_path(self, meta_data):
+ # 'rgbs': {'rgb_color': 'Hypersim/data/ai_001_001/images/scene_cam_00_final_preview/frame.0008.color.jpg',
+ # 'rgb_gamma': 'Hypersim/data/ai_001_001/images/scene_cam_00_final_preview/frame.0008.gamma.jpg',
+ # 'rgb_tonemap': 'Hypersim/data/ai_001_001/images/scene_cam_00_final_preview/frame.0008.tonemap.jpg',
+ # 'rgb_raw': 'Hypersim/data/ai_001_001/images/scene_cam_00_final_hdf5/frame.0008.color.hdf5'}
+        meta_data['rgb'] = meta_data['rgbs']['rgb_color'] # note: differs from BaseDataset
+ curr_rgb_path = os.path.join(self.data_root, meta_data['rgb'])
+ curr_depth_path = os.path.join(self.depth_root, meta_data['depth'])
+ curr_sem_path = os.path.join(self.sem_root, meta_data['sem']) \
+ if self.sem_root is not None and ('sem' in meta_data) and (meta_data['sem'] is not None) \
+ else None
+ curr_norm_path = os.path.join(self.norm_root, meta_data['normal']) \
+ if ('normal' in meta_data) and (meta_data['normal'] is not None) and (self.norm_root is not None) \
+ else None
+ curr_depth_mask_path = os.path.join(self.depth_mask_root, meta_data['depth_mask']) \
+ if self.depth_mask_root is not None and ('depth_mask' in meta_data) and (meta_data['depth_mask'] is not None) \
+ else None
+
+ data_path=dict(
+ rgb_path=curr_rgb_path,
+ depth_path=curr_depth_path,
+ sem_path=curr_sem_path,
+ normal_path=curr_norm_path,
+ depth_mask_path=curr_depth_mask_path,
+ )
+ return data_path
+
+ def load_rgb_depth(self, rgb_path: str, depth_path: str):
+ """
+ Load the rgb and depth map with the paths.
+ """
+ rgb = self.load_data(rgb_path, is_rgb_img=True)
+ if rgb is None:
+ self.logger.info(f'>>>>{rgb_path} has errors.')
+
+ # depth = self.load_data(depth_path)
+ with h5py.File(depth_path, "r") as f: depth = f["dataset"][:]
+ np.nan_to_num(depth, copy=False, nan=0) # fill nan in gt
+ if depth is None:
+ self.logger.info(f'{depth_path} has errors.')
+
+        depth = depth.astype(np.float64)
+
+ depth = self.process_depth(depth, rgb)
+ return rgb, depth
+
+
+ def load_norm_label(self, norm_path, H, W, depth, K):
+ with h5py.File(norm_path, "r") as f:
+ normal = f["dataset"][:]
+ np.nan_to_num(normal, copy=False, nan=0)
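+            # flip the y and z components, presumably to match this repo's camera coordinate convention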
+ normal[:,:,1:] *= -1
+        normal = normal.astype(np.float64)
+
+ return self.align_normal(normal, depth, K, H, W)
+
+ def process_depth(self, depth: np.array, rgb: np.array) -> np.array:
+ depth[depth>60000] = 0
+ depth = depth / self.metric_scale
+ return depth
+
+ def align_normal(self, normal, depth, K, H, W):
+ '''
+ Orientation of surface normals in hypersim is not always consistent
+ see https://github.com/apple/ml-hypersim/issues/26
+ '''
+ # inv K
+ K = np.array([[K[0], 0 ,K[2]],
+ [0, K[1], K[3]],
+ [0, 0, 1]])
+ inv_K = np.linalg.inv(K)
+ # reprojection depth to camera points
+ if H == 768 and W == 1024:
+ xy = self.xy
+ else:
+            print('image size is not 768x1024; rebuilding the uv mesh')
+ xy = creat_uv_mesh(H, W)
+ points = np.matmul(inv_K[:3, :3], xy).reshape(3, H, W)
+ points = depth * points
+ points = points.transpose((1,2,0))
+
+ # align normal
+ orient_mask = np.sum(normal * points, axis=2) > 0
+ normal[orient_mask] *= -1
+
+ return normal
\ No newline at end of file
diff --git a/training/mono/datasets/ibims_dataset.py b/training/mono/datasets/ibims_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..50a5318ce7e75afa18df9ec19360bd50eada5fdf
--- /dev/null
+++ b/training/mono/datasets/ibims_dataset.py
@@ -0,0 +1,92 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class IBIMSDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(IBIMSDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+ self.avg = torch.nn.AvgPool2d(kernel_size=7, stride=1, ceil_mode=False, count_include_pad=True, divisor_override=None)
+ self.unfold = torch.nn.Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
+ self.pad = torch.nn.ZeroPad2d(3)
+
+
+ def process_depth(self, depth, rgb):
+ depth[depth>50000] = 0
+ depth /= self.metric_scale
+ return depth
+
+ def load_batch(self, meta_data, data_path):
+ curr_intrinsic = meta_data['cam_in']
+ # load rgb/depth
+ curr_rgb, curr_depth = self.load_rgb_depth(data_path['rgb_path'], data_path['depth_path'])
+ # get semantic labels
+ curr_sem = self.load_sem_label(data_path['sem_path'], curr_depth)
+ # create camera model
+ curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], curr_intrinsic)
+ # get normal labels
+        curr_normal = self.load_norm_label(data_path['normal_path'], H=curr_rgb.shape[0], W=curr_rgb.shape[1], depth=curr_depth, K=curr_intrinsic) # note: differs from BaseDataset
+ # get depth mask
+ depth_mask = self.load_depth_valid_mask(data_path['depth_mask_path'])
+ curr_depth[~depth_mask] = -1
+ data_batch = dict(
+ curr_rgb = curr_rgb,
+ curr_depth = curr_depth,
+ curr_sem = curr_sem,
+ curr_normal = curr_normal,
+ curr_cam_model=curr_cam_model,
+ )
+ return data_batch
+
+ def load_norm_label(self, norm_path, H, W, depth, K):
+ depth = torch.from_numpy(depth).squeeze()
+ K = torch.Tensor([[K[0], 0 ,K[2]],
+ [0, K[1], K[3]],
+ [0, 0, 1]])
+ K_inv = K.inverse()
+
+ y, x = torch.meshgrid([torch.arange(0, 480, dtype=torch.float32),
+ torch.arange(0, 640, dtype=torch.float32)], indexing='ij')
+ x = x.reshape(1, 480*640)
+ y = y.reshape(1, 480*640)
+ ones = torch.ones_like(x)
+ coord_2d = torch.cat((x, y, ones), dim=0)
+
+ coord_3d = torch.matmul(K_inv, coord_2d).view(3, 480, 640)
+ coord_3d = (coord_3d * depth[None, :])[None, :]
+ coord_3d_mean = self.avg(coord_3d)
+
+ uf_coord_3d = self.unfold(coord_3d.permute(1, 0, 2, 3))
+ coord_3d_decenter = uf_coord_3d - coord_3d_mean.view(3, 1, (480-6)*(640-6))
+ coord_3d_decenter = coord_3d_decenter.permute(2, 0, 1)
+ cov = torch.bmm(coord_3d_decenter, coord_3d_decenter.permute(0, 2, 1))
+
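+        # the eigenvector of the local 3x3 covariance with the smallest eigenvalue
+        # approximates the surface normal of the neighbourhood (eigh returns ascending eigenvalues)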
+ eig = torch.linalg.eigh(cov)
+ #svd = torch.linalg.svd(coord_3d_decenter)
+ normal = (eig[1])[:, :, 0].float()
+ #normal = (svd[1])[:, 2, :]
+ normal = self.pad(normal.permute(1, 0).view(1, 3, (480-6), (640-6)))
+
+ orient_mask = (torch.sum(normal * coord_3d, axis=1) < 0).unsqueeze(1)
+ normal = normal * orient_mask - normal * (~orient_mask)
+ gt_normal = normal.squeeze().permute(1, 2, 0).numpy()
+ return gt_normal
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+ dataset_i = IBIMSDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/kitti_dataset.py b/training/mono/datasets/kitti_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2962f712ab348af8d2d3b46a2412b565a5446be2
--- /dev/null
+++ b/training/mono/datasets/kitti_dataset.py
@@ -0,0 +1,190 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class KITTIDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(KITTIDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+ def get_data_for_trainval(self, idx: int):
+ anno = self.annotations['files'][idx]
+ meta_data = self.load_meta_data(anno)
+
+ data_path = self.load_data_path(meta_data)
+ data_batch = self.load_batch(meta_data, data_path)
+ # if data_path['sem_path'] is not None:
+ # print(self.data_name)
+
+ curr_rgb, curr_depth, curr_normal, curr_sem, curr_cam_model = data_batch['curr_rgb'], data_batch['curr_depth'], data_batch['curr_normal'], data_batch['curr_sem'], data_batch['curr_cam_model']
+ #curr_stereo_depth = data_batch['curr_stereo_depth']
+
+ th = 352 # target size for bottom cropping, a common practice for kitti training
+ tw = 1216
+
+ ch = curr_rgb.shape[0]
+ cw = curr_rgb.shape[1]
+
+ h_start = ch - th
+ w_start = (cw - tw) // 2
+ w_end = w_start + tw
+
+ curr_intrinsic = meta_data['cam_in']
+
+ curr_rgb = curr_rgb[h_start:, w_start:w_end, :]
+ curr_depth = curr_depth[h_start:, w_start:w_end]
+
+ curr_normal = curr_normal[h_start:, w_start:w_end, :]
+ curr_sem = curr_sem[h_start:, w_start:w_end]
+
+ curr_intrinsic[2] = curr_intrinsic[2] - w_start # cw
+ curr_intrinsic[3] = curr_intrinsic[3] - h_start # ch
+
+ # A patch for stereo depth dataloader (no need to modify specific datasets)
+ if 'curr_stereo_depth' in data_batch.keys():
+ curr_stereo_depth = data_batch['curr_stereo_depth']
+ else:
+ curr_stereo_depth = self.load_stereo_depth_label(None, H=curr_rgb.shape[0], W=curr_rgb.shape[1])
+
+
+ # data augmentation
+ transform_paras = dict(random_crop_size = self.random_crop_size) # dict()
+ assert curr_rgb.shape[:2] == curr_depth.shape == curr_normal.shape[:2] == curr_sem.shape
+ rgbs, depths, intrinsics, cam_models, normals, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb, ],
+ labels=[curr_depth, ],
+ intrinsics=[curr_intrinsic,],
+ cam_models=[curr_cam_model, ],
+ normals = [curr_normal, ],
+ other_labels=[curr_sem, curr_stereo_depth],
+ transform_paras=transform_paras)
+ # process sky masks
+ sem_mask = other_labels[0].int()
+ # clip depth map
+ depth_out = self.normalize_depth(depths[0])
+        # set the depth of the sky region to invalid
+ depth_out[sem_mask==142] = -1 # self.depth_normalize[1] - 1e-6
+ # get inverse depth
+ inv_depth = self.depth2invdepth(depth_out, sem_mask==142)
+ filename = os.path.basename(meta_data['rgb'])[:-4] + '.jpg'
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+
+ # stereo_depth
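+        # keep stereo depth only within a plausible metric range (0.3 m to 200 m) before applying the label scale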
+ stereo_depth_pre_trans = other_labels[1] * (other_labels[1] > 0.3) * (other_labels[1] < 200)
+ stereo_depth = stereo_depth_pre_trans * transform_paras['label_scale_factor']
+ stereo_depth = self.normalize_depth(stereo_depth)
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ pad=torch.tensor(pad),
+ data_type=[self.data_type, ],
+ sem_mask=sem_mask.int(),
+ stereo_depth= stereo_depth,
+ normal=normals[0],
+ inv_depth=inv_depth,
+ scale=transform_paras['label_scale_factor'])
+ return data
+
+
+ def get_data_for_test(self, idx: int):
+ anno = self.annotations['files'][idx]
+ meta_data = self.load_meta_data(anno)
+ curr_rgb_path = os.path.join(self.data_root, meta_data['rgb'])
+ curr_depth_path = os.path.join(self.depth_root, meta_data['depth'])
+ # load data
+ ori_curr_intrinsic = meta_data['cam_in']
+ curr_rgb, curr_depth = self.load_rgb_depth(curr_rgb_path, curr_depth_path)
+ # crop rgb/depth
+ curr_rgb = curr_rgb[:, 43: 1197, :]
+ curr_depth = curr_depth[:, 43: 1197]
+
+ ori_h, ori_w, _ = curr_rgb.shape
+ # create camera model
+ curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], ori_curr_intrinsic)
+ # load tmpl rgb info
+ # tmpl_annos = self.load_tmpl_image_pose(curr_rgb, meta_data)
+ # tmpl_rgbs = tmpl_annos['tmpl_rgb_list'] # list of reference rgbs
+
+ # get crop size
+ transform_paras = dict()
+ rgbs, depths, intrinsics, cam_models, _, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb,], #+ tmpl_rgbs,
+ labels=[curr_depth, ],
+ intrinsics=[ori_curr_intrinsic, ], # * (len(tmpl_rgbs) + 1),
+ cam_models=[curr_cam_model, ],
+ transform_paras=transform_paras)
+
+        # depth in original size and original metric
+ depth_out = self.clip_depth(curr_depth) * self.depth_range[1] # self.clip_depth(depths[0]) #
+
+ filename = os.path.basename(meta_data['rgb'])
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ scale_ratio = transform_paras['label_scale_factor'] if 'label_scale_factor' in transform_paras else 1.0
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+ raw_rgb = torch.from_numpy(curr_rgb)
+ # rel_pose = torch.from_numpy(tmpl_annos['tmpl_pose_list'][0])
+
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ # ref_input=rgbs[1:],
+ # tmpl_flg=tmpl_annos['w_tmpl'],
+ pad=pad,
+ scale=scale_ratio,
+ raw_rgb=raw_rgb,
+ normal = np.zeros_like(curr_rgb.transpose((2,0,1))),
+ # rel_pose=rel_pose,
+ )
+ return data
+
+ def process_depth(self, depth, rgb):
+ new_depth = np.zeros_like(depth)
+ H, W = depth.shape
+ crop_h_up = int(0.3324324 * H)
+ crop_h_down = int(0.91351351 * H)
+ crop_w_left = int(0.0359477 * W)
+ crop_w_right = int(0.96405229 * W)
+
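+        # keep only the central region where LiDAR ground truth is typically valid,
+        # similar in spirit to the standard KITTI evaluation crops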
+ new_depth[crop_h_up:crop_h_down, crop_w_left: crop_w_right] = depth[crop_h_up:crop_h_down, crop_w_left: crop_w_right]
+ new_depth[new_depth>65500] = 0
+ new_depth /= self.metric_scale
+ #print('image size', new_depth.shape, crop_h_up, crop_h_down, crop_w_left, crop_w_right)
+ #self.logger.info('image size, {new_depth.shape}, {crop_h_up}, {crop_h_down}, {crop_w_left}, {crop_w_right}')
+ return new_depth
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+ dataset_i = KITTIDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/lyft_dataset.py b/training/mono/datasets/lyft_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e007d100917b0023291a07c9fb4a9427244c7cbe
--- /dev/null
+++ b/training/mono/datasets/lyft_dataset.py
@@ -0,0 +1,34 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+import pickle
+
+class LyftDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(LyftDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+
+ def process_depth(self, depth, rgb):
+ depth[depth>65500] = 0
+ depth /= self.metric_scale
+ return depth
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+    dataset_i = LyftDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/mapillary_psd_dataset.py b/training/mono/datasets/mapillary_psd_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f10899c7c7362d005ffeacf79aa2ce288fff1d4
--- /dev/null
+++ b/training/mono/datasets/mapillary_psd_dataset.py
@@ -0,0 +1,35 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+import matplotlib.pyplot as plt
+
+class MapillaryPSDDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(MapillaryPSDDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+ def process_depth(self, depth, rgb):
+ depth[depth>65500] = 0
+ depth /= self.metric_scale
+ h, w, _ = rgb.shape # to rgb size
+ depth_resize = cv2.resize(depth, (w, h), interpolation=cv2.INTER_NEAREST)
+ return depth_resize
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+    dataset_i = MapillaryPSDDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/matterport3d_dataset.py b/training/mono/datasets/matterport3d_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..27afefb74ed699ed9756f3ce3c3c5530a3dcb94b
--- /dev/null
+++ b/training/mono/datasets/matterport3d_dataset.py
@@ -0,0 +1,44 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from PIL import Image
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class Matterport3DDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(Matterport3DDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+ #self.cap_range = self.depth_range # in meter
+
+ def load_norm_label(self, norm_path, H, W):
+ normal_x = cv2.imread(norm_path['x'], cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
+ normal_y = cv2.imread(norm_path['y'], cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
+ normal_z = cv2.imread(norm_path['z'], cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
+ raw_normal = np.array([normal_x, normal_y, normal_z])
+ invalid_mask = np.all(raw_normal == 0, axis=0)
+
+ ego_normal = raw_normal.astype(np.float64) / 32768.0 - 1
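+        # rotate normals from the dataset's ego frame into the camera frame (flip the y and z axes)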
+ ego2cam = np.array([[1,0,0],
+ [0,-1,0],
+ [0,0,-1]])
+ normal = (ego2cam @ ego_normal.reshape(3,-1)).reshape(ego_normal.shape)
+ normal[:,invalid_mask] = 0
+ normal = normal.transpose((1,2,0))
+ if normal.shape[0] != H or normal.shape[1] != W:
+ normal = cv2.resize(normal, [W,H], interpolation=cv2.INTER_NEAREST)
+ return normal
+
+ def process_depth(self, depth: np.array, rgb: np.array) -> np.array:
+ depth[depth>65500] = 0
+ depth = depth / self.metric_scale
+ return depth
diff --git a/training/mono/datasets/nuscenes_dataset.py b/training/mono/datasets/nuscenes_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b72e1c073bad9aeb8b237a1d8c33eec55fefe0be
--- /dev/null
+++ b/training/mono/datasets/nuscenes_dataset.py
@@ -0,0 +1,34 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+import pickle
+
+class NuScenesDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(NuScenesDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+
+ def process_depth(self, depth, rgb):
+ depth[depth>65500] = 0
+ depth /= self.metric_scale
+ return depth
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+    dataset_i = NuScenesDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/nyu_dataset.py b/training/mono/datasets/nyu_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ad8a1a0b84a1abd1dd055f385db2322fa8d0cb9
--- /dev/null
+++ b/training/mono/datasets/nyu_dataset.py
@@ -0,0 +1,195 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class NYUDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(NYUDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+ def get_data_for_trainval(self, idx: int):
+ anno = self.annotations['files'][idx]
+ meta_data = self.load_meta_data(anno)
+
+ data_path = self.load_data_path(meta_data)
+ data_batch = self.load_batch(meta_data, data_path)
+ # if data_path['sem_path'] is not None:
+ # print(self.data_name)
+
+ curr_rgb, curr_depth, curr_normal, curr_sem, curr_cam_model = data_batch['curr_rgb'], data_batch['curr_depth'], data_batch['curr_normal'], data_batch['curr_sem'], data_batch['curr_cam_model']
+ #curr_stereo_depth = data_batch['curr_stereo_depth']
+ new_rgb = np.zeros_like(curr_rgb)
+ new_rgb[6:-6, 6:-6, :] = curr_rgb[6:-6, 6:-6, :]
+ curr_rgb = new_rgb
+
+ # A patch for stereo depth dataloader (no need to modify specific datasets)
+ if 'curr_stereo_depth' in data_batch.keys():
+ curr_stereo_depth = data_batch['curr_stereo_depth']
+ else:
+ curr_stereo_depth = self.load_stereo_depth_label(None, H=curr_rgb.shape[0], W=curr_rgb.shape[1])
+
+ curr_intrinsic = meta_data['cam_in']
+ # data augmentation
+ transform_paras = dict(random_crop_size = self.random_crop_size) # dict()
+ assert curr_rgb.shape[:2] == curr_depth.shape == curr_normal.shape[:2] == curr_sem.shape
+ rgbs, depths, intrinsics, cam_models, normals, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb, ],
+ labels=[curr_depth, ],
+ intrinsics=[curr_intrinsic,],
+ cam_models=[curr_cam_model, ],
+ normals = [curr_normal, ],
+ other_labels=[curr_sem, curr_stereo_depth],
+ transform_paras=transform_paras)
+ # process sky masks
+ sem_mask = other_labels[0].int()
+ # clip depth map
+ depth_out = self.normalize_depth(depths[0])
+        # set the depth of the sky region to invalid
+ depth_out[sem_mask==142] = -1 # self.depth_normalize[1] - 1e-6
+ # get inverse depth
+ inv_depth = self.depth2invdepth(depth_out, sem_mask==142)
+ filename = os.path.basename(meta_data['rgb'])[:-4] + '.jpg'
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+
+ # stereo_depth
+ stereo_depth_pre_trans = other_labels[1] * (other_labels[1] > 0.3) * (other_labels[1] < 200)
+ stereo_depth = stereo_depth_pre_trans * transform_paras['label_scale_factor']
+ stereo_depth = self.normalize_depth(stereo_depth)
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ pad=torch.tensor(pad),
+ data_type=[self.data_type, ],
+ sem_mask=sem_mask.int(),
+ stereo_depth= stereo_depth,
+ normal=normals[0],
+ inv_depth=inv_depth,
+ scale=transform_paras['label_scale_factor'])
+ return data
+
+ def get_data_for_test(self, idx: int):
+ anno = self.annotations['files'][idx]
+ meta_data = self.load_meta_data(anno)
+ curr_rgb_path = os.path.join(self.data_root, meta_data['rgb'])
+ curr_depth_path = os.path.join(self.depth_root, meta_data['depth'])
+ # load data
+ ori_curr_intrinsic = meta_data['cam_in']
+ curr_rgb, curr_depth = self.load_rgb_depth(curr_rgb_path, curr_depth_path)
+ # crop rgb/depth
+ new_rgb = np.zeros_like(curr_rgb)
+ new_rgb[6:-6, 6:-6, :] = curr_rgb[6:-6, 6:-6, :]
+ curr_rgb = new_rgb
+
+ ori_h, ori_w, _ = curr_rgb.shape
+ # create camera model
+ curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], ori_curr_intrinsic)
+
+ if 'normal' in meta_data.keys():
+ normal_path = os.path.join(self.data_root, meta_data['normal'])
+ else:
+ normal_path = None
+
+ curr_normal = self.load_norm_label(normal_path, H=curr_rgb.shape[0], W=curr_rgb.shape[1])
+ # load tmpl rgb info
+ # tmpl_annos = self.load_tmpl_image_pose(curr_rgb, meta_data)
+ # tmpl_rgbs = tmpl_annos['tmpl_rgb_list'] # list of reference rgbs
+
+ # get crop size
+ transform_paras = dict()
+ rgbs, depths, intrinsics, cam_models, normals, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb,], #+ tmpl_rgbs,
+ labels=[curr_depth, ],
+ intrinsics=[ori_curr_intrinsic, ], # * (len(tmpl_rgbs) + 1),
+ cam_models=[curr_cam_model, ],
+ normals = [curr_normal, ],
+ transform_paras=transform_paras)
+        # depth in original size and original metric
+ depth_out = self.clip_depth(curr_depth) * self.depth_range[1] # self.clip_depth(depths[0]) #
+
+ filename = os.path.basename(meta_data['rgb'])
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ scale_ratio = transform_paras['label_scale_factor'] if 'label_scale_factor' in transform_paras else 1.0
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+ raw_rgb = torch.from_numpy(curr_rgb)
+ # rel_pose = torch.from_numpy(tmpl_annos['tmpl_pose_list'][0])
+ curr_normal = torch.from_numpy(curr_normal.transpose((2,0,1)))
+
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ # ref_input=rgbs[1:],
+ # tmpl_flg=tmpl_annos['w_tmpl'],
+ pad=pad,
+ scale=scale_ratio,
+ raw_rgb=raw_rgb,
+ # rel_pose=rel_pose,
+ normal=curr_normal
+ #normal=np.zeros_like(curr_rgb.transpose((2,0,1))),
+ )
+ return data
+
+ def load_norm_label(self, norm_path, H, W):
+ if norm_path is None:
+ norm_gt = np.zeros((H, W, 3)).astype(np.float32)
+ else:
+ norm_gt = cv2.imread(norm_path)
+
+ norm_gt = np.array(norm_gt).astype(np.uint8)
+ norm_valid_mask = np.logical_not(
+ np.logical_and(
+ np.logical_and(
+ norm_gt[:, :, 0] == 0, norm_gt[:, :, 1] == 0),
+ norm_gt[:, :, 2] == 0))
+ norm_valid_mask = norm_valid_mask[:, :, np.newaxis]
+
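+        # map from [0, 255] to [-1, 1], zero out invalid pixels, and flip the sign convention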
+ norm_gt = ((norm_gt.astype(np.float32) / 255.0) * 2.0) - 1.0
+ norm_gt = norm_gt * norm_valid_mask * -1
+
+ return norm_gt
+
+ def process_depth(self, depth, rgb):
+        # Eigen crop: keep only the standard NYUv2 evaluation region [45:471, 41:601]
+ new_depth = np.zeros_like(depth)
+ new_depth[45:471, 41:601] = depth[45:471, 41:601]
+
+ new_depth[new_depth>65500] = 0
+ new_depth /= self.metric_scale
+ return new_depth
+
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+ dataset_i = NYUDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/pandaset_dataset.py b/training/mono/datasets/pandaset_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6defd829967d97650f6131764da275828579a25
--- /dev/null
+++ b/training/mono/datasets/pandaset_dataset.py
@@ -0,0 +1,36 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class PandasetDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(PandasetDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+
+    def process_depth(self, depth, rgb):
+        depth[depth>65500] = 0
+        depth /= self.metric_scale
+        return depth
diff --git a/training/mono/datasets/scannet_dataset.py b/training/mono/datasets/scannet_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e51cd3fa686656c2d1c5d535358d0d9c471a35d
--- /dev/null
+++ b/training/mono/datasets/scannet_dataset.py
@@ -0,0 +1,295 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class ScanNetDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(ScanNetDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+ # def get_data_for_test(self, idx):
+ # anno = self.annotations['files'][idx]
+ # curr_rgb_path = os.path.join(self.data_root, anno['rgb'])
+ # curr_depth_path = os.path.join(self.depth_root, anno['depth'])
+ # meta_data = self.load_meta_data(anno)
+ # ori_curr_intrinsic = meta_data['cam_in']
+
+ # curr_rgb, curr_depth = self.load_rgb_depth(curr_rgb_path, curr_depth_path)
+ # # curr_rgb = cv2.resize(curr_rgb, dsize=(640, 480), interpolation=cv2.INTER_LINEAR)
+ # ori_h, ori_w, _ = curr_rgb.shape
+ # # create camera model
+ # curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], ori_curr_intrinsic)
+ # # load tmpl rgb info
+ # # tmpl_annos = self.load_tmpl_annos(anno, curr_rgb, meta_data)
+ # # tmpl_rgb = tmpl_annos['tmpl_rgb_list'] # list of reference rgbs
+
+ # transform_paras = dict()
+ # rgbs, depths, intrinsics, cam_models, _, other_labels, transform_paras = self.img_transforms(
+ # images=[curr_rgb, ],
+ # labels=[curr_depth, ],
+ # intrinsics=[ori_curr_intrinsic,],
+ # cam_models=[curr_cam_model, ],
+ # transform_paras=transform_paras)
+ # # depth in original size
+ # depth_out = self.clip_depth(curr_depth) * self.depth_range[1]
+
+ # filename = os.path.basename(anno['rgb'])
+ # curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+
+ # pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ # scale_ratio = transform_paras['label_scale_factor'] if 'label_scale_factor' in transform_paras else 1.0
+ # cam_models_stacks = [
+ # torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ # for i in [2, 4, 8, 16, 32]
+ # ]
+ # raw_rgb = torch.from_numpy(curr_rgb)
+ # data = dict(input=rgbs[0],
+ # target=depth_out,
+ # intrinsic=curr_intrinsic_mat,
+ # filename=filename,
+ # dataset=self.data_name,
+ # cam_model=cam_models_stacks,
+ # ref_input=rgbs[1:],
+ # tmpl_flg=False,
+ # pad=pad,
+ # scale=scale_ratio,
+ # raw_rgb=raw_rgb,
+ # normal =np.zeros_like(curr_rgb.transpose((2,0,1))),
+ # )
+ # return data
+
+ def get_data_for_test(self, idx: int, test_mode=True):
+ anno = self.annotations['files'][idx]
+ meta_data = self.load_meta_data(anno)
+ data_path = self.load_data_path(meta_data)
+ data_batch = self.load_batch(meta_data, data_path, test_mode)
+ # load data
+ curr_rgb, curr_depth, curr_normal, curr_cam_model = data_batch['curr_rgb'], data_batch['curr_depth'], data_batch['curr_normal'], data_batch['curr_cam_model']
+ ori_curr_intrinsic = meta_data['cam_in']
+
+ # get crop size
+ transform_paras = dict()
+ rgbs, depths, intrinsics, cam_models, _, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb,], #+ tmpl_rgbs,
+ labels=[curr_depth, ],
+ intrinsics=[ori_curr_intrinsic, ], # * (len(tmpl_rgbs) + 1),
+ cam_models=[curr_cam_model, ],
+ transform_paras=transform_paras)
+        # depth in original size and original metric
+ depth_out = self.clip_depth(curr_depth) * self.depth_range[1] # self.clip_depth(depths[0]) #
+        inv_depth = self.depth2invdepth(depth_out, np.zeros_like(depth_out, dtype=bool))
+ filename = os.path.basename(meta_data['rgb'])[:-4] + '.jpg'
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ scale_ratio = transform_paras['label_scale_factor'] if 'label_scale_factor' in transform_paras else 1.0
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+ raw_rgb = torch.from_numpy(curr_rgb)
+ curr_normal = torch.from_numpy(curr_normal.transpose((2,0,1)))
+
+
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ pad=pad,
+ scale=scale_ratio,
+ raw_rgb=raw_rgb,
+ sample_id=idx,
+ data_path=meta_data['rgb'],
+ inv_depth=inv_depth,
+ normal=curr_normal,
+ )
+ return data
+
+ def get_data_for_trainval(self, idx: int):
+ anno = self.annotations['files'][idx]
+ meta_data = self.load_meta_data(anno)
+
+ data_path = self.load_data_path(meta_data)
+ data_batch = self.load_batch(meta_data, data_path, test_mode=False)
+
+ # if data_path['sem_path'] is not None:
+ # print(self.data_name)
+
+ curr_rgb, curr_depth, curr_normal, curr_sem, curr_cam_model = data_batch['curr_rgb'], data_batch['curr_depth'], data_batch['curr_normal'], data_batch['curr_sem'], data_batch['curr_cam_model']
+ #curr_stereo_depth = data_batch['curr_stereo_depth']
+
+ # A patch for stereo depth dataloader (no need to modify specific datasets)
+ if 'curr_stereo_depth' in data_batch.keys():
+ curr_stereo_depth = data_batch['curr_stereo_depth']
+ else:
+ curr_stereo_depth = self.load_stereo_depth_label(None, H=curr_rgb.shape[0], W=curr_rgb.shape[1])
+
+ curr_intrinsic = meta_data['cam_in']
+ # data augmentation
+ transform_paras = dict(random_crop_size = self.random_crop_size) # dict()
+ assert curr_rgb.shape[:2] == curr_depth.shape == curr_normal.shape[:2] == curr_sem.shape
+ rgbs, depths, intrinsics, cam_models, normals, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb, ],
+ labels=[curr_depth, ],
+ intrinsics=[curr_intrinsic,],
+ cam_models=[curr_cam_model, ],
+ normals = [curr_normal, ],
+ other_labels=[curr_sem, curr_stereo_depth],
+ transform_paras=transform_paras)
+ # process sky masks
+ sem_mask = other_labels[0].int()
+ # clip depth map
+ depth_out = self.normalize_depth(depths[0])
+        # set the depth of the sky region to invalid
+ depth_out[sem_mask==142] = -1 # self.depth_normalize[1] - 1e-6
+ # get inverse depth
+ inv_depth = self.depth2invdepth(depth_out, sem_mask==142)
+ filename = os.path.basename(meta_data['rgb'])[:-4] + '.jpg'
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+
+ # stereo_depth
+ stereo_depth_pre_trans = other_labels[1] * (other_labels[1] > 0.3) * (other_labels[1] < 200)
+ stereo_depth = stereo_depth_pre_trans * transform_paras['label_scale_factor']
+ stereo_depth = self.normalize_depth(stereo_depth)
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ pad=torch.tensor(pad),
+ data_type=[self.data_type, ],
+ sem_mask=sem_mask.int(),
+ stereo_depth= stereo_depth,
+ normal=normals[0],
+ inv_depth=inv_depth,
+ scale=transform_paras['label_scale_factor'])
+ return data
+
+ def load_batch(self, meta_data, data_path, test_mode):
+
+ # print('############')
+ # print(data_path['rgb_path'])
+ # print(data_path['normal_path'])
+ # print('############')
+
+ curr_intrinsic = meta_data['cam_in']
+ # load rgb/depth
+ curr_rgb, curr_depth = self.load_rgb_depth(data_path['rgb_path'], data_path['depth_path'], test_mode)
+ # get semantic labels
+ curr_sem = self.load_sem_label(data_path['sem_path'], curr_depth)
+ # create camera model
+ curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], curr_intrinsic)
+ # get normal labels
+ curr_normal = self.load_norm_label(data_path['normal_path'], H=curr_rgb.shape[0], W=curr_rgb.shape[1], test_mode=test_mode)
+ # get depth mask
+ depth_mask = self.load_depth_valid_mask(data_path['depth_mask_path'])
+ curr_depth[~depth_mask] = -1
+ # get stereo depth
+ curr_stereo_depth = self.load_stereo_depth_label(data_path['disp_path'], H=curr_rgb.shape[0], W=curr_rgb.shape[1])
+
+ data_batch = dict(
+ curr_rgb = curr_rgb,
+ curr_depth = curr_depth,
+ curr_sem = curr_sem,
+ curr_normal = curr_normal,
+ curr_cam_model=curr_cam_model,
+ curr_stereo_depth=curr_stereo_depth,
+ )
+ return data_batch
+
+ def load_rgb_depth(self, rgb_path: str, depth_path: str, test_mode: bool):
+ """
+ Load the rgb and depth map with the paths.
+ """
+ rgb = self.load_data(rgb_path, is_rgb_img=True)
+ if rgb is None:
+ self.logger.info(f'>>>>{rgb_path} has errors.')
+
+ depth = self.load_data(depth_path)
+ if depth is None:
+ self.logger.info(f'{depth_path} has errors.')
+
+ # self.check_data(dict(
+ # rgb_path=rgb,
+ # depth_path=depth,
+ # ))
+        depth = depth.astype(np.float64)  # np.float was removed in NumPy 1.24; use an explicit dtype
+ # if depth.shape != rgb.shape[:2]:
+ # print(f'no-equal in {self.data_name}')
+ # depth = cv2.resize(depth, rgb.shape[::-1][1:])
+
+ depth = self.process_depth(depth, rgb, test_mode)
+ return rgb, depth
+
+ def process_depth(self, depth, rgb, test_mode=False):
+ depth[depth>65500] = 0
+ depth /= self.metric_scale
+ h, w, _ = rgb.shape # to rgb size
+ if test_mode==False:
+ depth = cv2.resize(depth, (w, h), interpolation=cv2.INTER_NEAREST)
+ return depth
+
+ def load_norm_label(self, norm_path, H, W, test_mode):
+
+ if norm_path is None:
+ norm_gt = np.zeros((H, W, 3)).astype(np.float32)
+ else:
+ norm_gt = cv2.imread(norm_path)
+ norm_gt = cv2.cvtColor(norm_gt, cv2.COLOR_BGR2RGB)
+
+ norm_gt = np.array(norm_gt).astype(np.uint8)
+
+ mask_path = 'orient-mask'.join(norm_path.rsplit('normal', 1))
+ mask_gt = cv2.imread(mask_path)
+ mask_gt = np.array(mask_gt).astype(np.uint8)
+ valid_mask = np.logical_not(
+ np.logical_and(
+ np.logical_and(
+ mask_gt[:, :, 0] == 0, mask_gt[:, :, 1] == 0),
+ mask_gt[:, :, 2] == 0))
+ valid_mask = valid_mask[:, :, np.newaxis]
+
+ # norm_valid_mask = np.logical_not(
+ # np.logical_and(
+ # np.logical_and(
+ # norm_gt[:, :, 0] == 0, norm_gt[:, :, 1] == 0),
+ # norm_gt[:, :, 2] == 0))
+ # norm_valid_mask = norm_valid_mask[:, :, np.newaxis]
+
+ norm_gt = ((norm_gt.astype(np.float32) / 255.0) * 2.0) - 1.0
+ norm_valid_mask = (np.linalg.norm(norm_gt, axis=2, keepdims=True) > 0.5) * valid_mask
+ norm_gt = norm_gt * norm_valid_mask
+
+ if test_mode==False:
+ norm_gt = cv2.resize(norm_gt, (W, H), interpolation=cv2.INTER_NEAREST)
+
+ return norm_gt
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+ dataset_i = NYUDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/taskonomy_dataset.py b/training/mono/datasets/taskonomy_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f0982108f45fc77a72dcd05287841de5c89a9d4
--- /dev/null
+++ b/training/mono/datasets/taskonomy_dataset.py
@@ -0,0 +1,190 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from PIL import Image
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+import pickle
+
+
+class TaskonomyDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(TaskonomyDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+ #self.cap_range = self.depth_range # in meter
+
+ def __getitem__(self, idx: int) -> dict:
+ if self.phase == 'test':
+ return self.get_data_for_test(idx)
+ else:
+ return self.get_data_for_trainval(idx)
+
+ def load_meta_data(self, anno: dict) -> dict:
+ """
+ Load meta data information.
+ """
+ if self.meta_data_root is not None and ('meta_data' in anno or 'meta' in anno):
+ meta_data_path = os.path.join(self.meta_data_root, anno['meta_data']) if 'meta_data' in anno else os.path.join(self.meta_data_root, anno['meta'])
+ with open(meta_data_path, 'rb') as f:
+ meta_data = pickle.load(f)
+ meta_data.update(anno)
+ else:
+ meta_data = anno
+ u0, v0, fx, fy = meta_data['cam_in']
+        meta_data['cam_in'] = [fx, fy, u0, v0] # annotations store (u0, v0, fx, fy); reorder to the (fx, fy, u0, v0) convention used elsewhere
+ return meta_data
+
+ def get_data_for_trainval(self, idx: int):
+ anno = self.annotations['files'][idx]
+ meta_data = self.load_meta_data(anno)
+
+ data_path = self.load_data_path(meta_data)
+ data_batch = self.load_batch(meta_data, data_path)
+ curr_rgb, curr_depth, curr_normal, curr_cam_model = data_batch['curr_rgb'], data_batch['curr_depth'], data_batch['curr_normal'], data_batch['curr_cam_model']
+ curr_intrinsic = meta_data['cam_in']
+
+ ins_planes_path = os.path.join(self.data_root, meta_data['ins_planes']) if ('ins_planes' in meta_data) and (meta_data['ins_planes'] is not None) else None
+ # get instance planes
+ ins_planes = self.load_ins_planes(curr_depth, ins_planes_path)
+
+ # load data
+ # u0, v0, fx, fy = meta_data['cam_in'] # this is
+ # ori_curr_intrinsic = [fx, fy, u0, v0]
+ # curr_rgb, curr_depth = self.load_rgb_depth(curr_rgb_path, curr_depth_path)
+
+ # get crop size
+ # transform_paras = dict()
+ transform_paras = dict(random_crop_size = self.random_crop_size)
+ rgbs, depths, intrinsics, cam_models, normals, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb, ],
+ labels=[curr_depth, ],
+ intrinsics=[curr_intrinsic,],
+ cam_models=[curr_cam_model, ],
+ normals = [curr_normal, ],
+ other_labels=[ins_planes, ],
+ transform_paras=transform_paras)
+ # process instance planes
+ ins_planes = other_labels[0].int()
+
+ # clip depth map
+ depth_out = self.normalize_depth(depths[0])
+ # get inverse depth
+ inv_depth = self.depth2invdepth(depth_out, torch.zeros_like(depth_out, dtype=torch.bool))
+ filename = os.path.basename(meta_data['rgb'])
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ pad=torch.tensor(pad),
+ data_type=[self.data_type, ],
+ sem_mask=ins_planes,
+ normal=normals[0],
+ inv_depth=inv_depth,
+ stereo_depth=torch.zeros_like(inv_depth),
+ scale= transform_paras['label_scale_factor'])
+ return data
+
+ def get_data_for_test(self, idx: int):
+ anno = self.annotations['files'][idx]
+ meta_data = self.load_meta_data(anno)
+ data_path = self.load_data_path(meta_data)
+ data_batch = self.load_batch(meta_data, data_path)
+
+ curr_rgb, curr_depth, curr_normal, curr_cam_model = data_batch['curr_rgb'], data_batch['curr_depth'], data_batch['curr_normal'], data_batch['curr_cam_model']
+ ori_curr_intrinsic = meta_data['cam_in']
+
+ # curr_rgb_path = os.path.join(self.data_root, meta_data['rgb'])
+ # curr_depth_path = os.path.join(self.depth_root, meta_data['depth'])
+
+ # curr_rgb, curr_depth = self.load_rgb_depth(curr_rgb_path, curr_depth_path)
+ # ori_h, ori_w, _ = curr_rgb.shape
+ # # create camera model
+ # curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], ori_curr_intrinsic)
+ # load tmpl rgb info
+ # tmpl_annos = self.load_tmpl_image_pose(curr_rgb, meta_data)
+ # tmpl_rgbs = tmpl_annos['tmpl_rgb_list'] # list of reference rgbs
+
+ transform_paras = dict()
+ rgbs, depths, intrinsics, cam_models, _, other_labels, transform_paras = self.img_transforms(
+ images=[curr_rgb,], # + tmpl_rgbs,
+ labels=[curr_depth, ],
+ intrinsics=[ori_curr_intrinsic, ], # * (len(tmpl_rgbs) + 1),
+ cam_models=[curr_cam_model, ],
+ transform_paras=transform_paras)
+ # depth in original size and orignial metric***
+ depth_out = self.clip_depth(curr_depth) * self.depth_range[1]
+        inv_depth = self.depth2invdepth(depth_out, np.zeros_like(depth_out, dtype=bool))  # np.bool was removed in NumPy 1.24
+
+ filename = os.path.basename(meta_data['rgb'])
+ curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
+
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ scale_ratio = transform_paras['label_scale_factor'] if 'label_scale_factor' in transform_paras else 1.0
+ cam_models_stacks = [
+ torch.nn.functional.interpolate(cam_models[0][None, :, :, :], size=(cam_models[0].shape[1]//i, cam_models[0].shape[2]//i), mode='bilinear', align_corners=False).squeeze()
+ for i in [2, 4, 8, 16, 32]
+ ]
+ raw_rgb = torch.from_numpy(curr_rgb)
+ curr_normal = torch.from_numpy(curr_normal.transpose((2,0,1)))
+
+ data = dict(input=rgbs[0],
+ target=depth_out,
+ intrinsic=curr_intrinsic_mat,
+ filename=filename,
+ dataset=self.data_name,
+ cam_model=cam_models_stacks,
+ pad=pad,
+ scale=scale_ratio,
+ raw_rgb=raw_rgb,
+ sample_id=idx,
+ data_path=meta_data['rgb'],
+ inv_depth=inv_depth,
+ normal=curr_normal,
+ )
+ return data
+
+ def load_norm_label(self, norm_path, H, W):
+ with open(norm_path, 'rb') as f:
+ normal = Image.open(f)
+ normal = np.array(normal.convert(normal.mode), dtype=np.uint8)
+ invalid_mask = np.all(normal == 128, axis=2)
+ normal = normal.astype(np.float64) / 255.0 * 2 - 1
+ normal[invalid_mask, :] = 0
+ return normal
+
+ def process_depth(self, depth: np.array, rgb: np.array) -> np.array:
+ depth[depth>60000] = 0
+ depth = depth / self.metric_scale
+ return depth
+
+ def load_ins_planes(self, depth: np.array, ins_planes_path: str) -> np.array:
+ if ins_planes_path is not None:
+ ins_planes = cv2.imread(ins_planes_path, -1)
+ else:
+ ins_planes = np.zeros_like(depth)
+ return ins_planes
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+    dataset_i = TaskonomyDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/uasol_dataset.py b/training/mono/datasets/uasol_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1ccab0240a747bbf070f02d4da60f7703975dd3
--- /dev/null
+++ b/training/mono/datasets/uasol_dataset.py
@@ -0,0 +1,52 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class UASOLDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(UASOLDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+
+ def process_depth(self, depth, rgb):
+ depth[depth>65500] = 0
+ depth /= self.metric_scale
+ return depth
+
+ def load_rgb_depth(self, rgb_path: str, depth_path: str) -> (np.array, np.array):
+ """
+ Load the rgb and depth map with the paths.
+ """
+ rgb = self.load_data(rgb_path, is_rgb_img=True)
+ if rgb is None:
+ self.logger.info(f'>>>>{rgb_path} has errors.')
+
+ depth = self.load_data(depth_path)
+ if depth is None:
+ self.logger.info(f'{depth_path} has errors.')
+
+        depth = depth.astype(np.float64)  # np.float was removed in NumPy 1.24; use an explicit dtype
+
+ depth = self.process_depth(depth, rgb)
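+        # Drop the first and last image rows; presumably this trims the UASOL depth maps to match
+        # the RGB resolution used elsewhere in the pipeline.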
+ depth = depth[1:-1, ...]
+ return rgb, depth
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+ dataset_i = UASOLDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/virtualkitti_dataset.py b/training/mono/datasets/virtualkitti_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ac6b76fc97cabfb46d8c814767af27cb062b6d6
--- /dev/null
+++ b/training/mono/datasets/virtualkitti_dataset.py
@@ -0,0 +1,65 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class VKITTIDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(VKITTIDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+
+
+ def process_depth(self, depth, rgb):
+ depth[depth>(150 * self.metric_scale)] = 0
+ depth /= self.metric_scale
+
+ return depth
+
+ def load_sem_label(self, sem_path, depth=None, sky_id=142) -> np.array:
+ """
+ Category r g b
+ Terrain 210 0 200
+ Sky 90 200 255
+ Tree 0 199 0
+ Vegetation 90 240 0
+ Building 140 140 140
+ Road 100 60 100
+ GuardRail 250 100 255
+ TrafficSign 255 255 0
+ TrafficLight 200 200 0
+ Pole 255 130 0
+ Misc 80 80 80
+ Truck 160 60 60
+ Car 255 127 80
+ Van 0 139 139
+ """
+ H, W = depth.shape
+        sem_label = np.ones((H, W), dtype=np.int64) * -1  # np.int was removed in NumPy 1.24
+        sem = cv2.imread(sem_path)
+        if sem is None:
+            return sem_label
+        sem = sem[:, :, ::-1]  # BGR -> RGB so the colors match the table above
+
+ sky_color = [90, 200, 255]
+ sky_mask = (sem == sky_color).all(axis=2)
+ sem_label[sky_mask] = 142 # set sky region to 142
+ return sem_label
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+    dataset_i = VKITTIDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/datasets/waymo_dataset.py b/training/mono/datasets/waymo_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5611f5a3dbf2bdf266f506f3cbb762dc8d7e1025
--- /dev/null
+++ b/training/mono/datasets/waymo_dataset.py
@@ -0,0 +1,34 @@
+import os
+import json
+import torch
+import torchvision.transforms as transforms
+import os.path
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+import random
+from .__base_dataset__ import BaseDataset
+
+
+class WaymoDataset(BaseDataset):
+ def __init__(self, cfg, phase, **kwargs):
+ super(WaymoDataset, self).__init__(
+ cfg=cfg,
+ phase=phase,
+ **kwargs)
+ self.metric_scale = cfg.metric_scale
+
+
+ def process_depth(self, depth, rgb):
+ depth[depth>65500] = 0
+ depth /= 200.0
+ return depth
+
+
+
+if __name__ == '__main__':
+ from mmcv.utils import Config
+ cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
+    dataset_i = WaymoDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
+ print(dataset_i)
+
\ No newline at end of file
diff --git a/training/mono/model/__base_model__.py b/training/mono/model/__base_model__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0b483c8cc179d2aeac9fafda68ec945123b6229
--- /dev/null
+++ b/training/mono/model/__base_model__.py
@@ -0,0 +1,288 @@
+import torch
+import torch.nn as nn
+from mono.utils.comm import get_func
+import numpy as np
+import torch.nn.functional as F
+
+
+class BaseDepthModel(nn.Module):
+    def __init__(self, cfg, criterions, **kwargs):
+ super(BaseDepthModel, self).__init__()
+ model_type = cfg.model.type
+ self.depth_model = get_func('mono.model.model_pipelines.' + model_type)(cfg)
+
+ self.criterions_main = criterions['decoder_losses'] if criterions and 'decoder_losses' in criterions else None
+ self.criterions_auxi = criterions['auxi_losses'] if criterions and 'auxi_losses' in criterions else None
+ self.criterions_pose = criterions['pose_losses'] if criterions and 'pose_losses' in criterions else None
+ self.criterions_gru = criterions['gru_losses'] if criterions and 'gru_losses' in criterions else None
+ try:
+ self.downsample = cfg.prediction_downsample
+        except AttributeError:  # prediction_downsample is optional in the config
+ self.downsample = None
+
+ self.training = True
+
+ def forward(self, data):
+        if self.downsample is not None:
+ self.label_downsample(self.downsample, data)
+
+ output = self.depth_model(**data)
+
+ losses_dict = {}
+ if self.training:
+ output.update(data)
+ losses_dict = self.get_loss(output)
+
+        if self.downsample is not None:
+ self.pred_upsample(self.downsample, output)
+
+ return output['prediction'], losses_dict, output['confidence']
+
+ def inference(self, data):
+ with torch.no_grad():
+ output = self.depth_model(**data)
+ output.update(data)
+
+        if self.downsample is not None:
+ self.pred_upsample(self.downsample, output)
+
+ output['dataset'] = 'wild'
+ return output
+
+ def get_loss(self, paras):
+ losses_dict = {}
+ # Losses for training
+ if self.training:
+ # decode branch
+ losses_dict.update(self.compute_decoder_loss(paras))
+ # auxilary branch
+ losses_dict.update(self.compute_auxi_loss(paras))
+ # pose branch
+ losses_dict.update(self.compute_pose_loss(paras))
+ # GRU sequence branch
+ losses_dict.update(self.compute_gru_loss(paras))
+
+ total_loss = sum(losses_dict.values())
+ losses_dict['total_loss'] = total_loss
+ return losses_dict
+
+ def compute_gru_loss(self, paras_):
+ losses_dict = {}
+ if self.criterions_gru is None or len(self.criterions_gru) == 0:
+ return losses_dict
+ paras = {k:v for k,v in paras_.items() if k!='prediction' and k!='prediction_normal'}
+ n_predictions = len(paras['predictions_list'])
+ for i, pre in enumerate(paras['predictions_list']):
+ if i == n_predictions-1:
+ break
+ #if i % 3 != 0:
+ #continue
+ if 'normal_out_list' in paras.keys():
+ pre_normal = paras['normal_out_list'][i]
+ else:
+ pre_normal = None
+ iter_dict = self.branch_loss(
+ prediction=pre,
+ prediction_normal=pre_normal,
+ criterions=self.criterions_gru,
+ branch=f'gru_{i}',
+ **paras
+ )
+ # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo iterations
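+            # e.g. with n_predictions = 8, adjusted_loss_gamma = 0.9**(15/7) ≈ 0.798, so the first
+            # supervised iteration is weighted 0.798**7 = 0.9**15 ≈ 0.206 and the last kept one
+            # (i = n_predictions - 2) gets ≈ 0.798, i.e. the same end-point weight as a schedule
+            # with 16 predictions and gamma = 0.9.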
+ adjusted_loss_gamma = 0.9**(15/(n_predictions - 1))
+ i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
+ iter_dict = {k:v*i_weight for k,v in iter_dict.items()}
+ losses_dict.update(iter_dict)
+ return losses_dict
+
+    def compute_decoder_loss(self, paras):
+        decode_losses_dict = self.branch_loss(
+            criterions=self.criterions_main,
+            branch='decode',
+            **paras
+        )
+        return decode_losses_dict
+
+ def compute_auxi_loss(self, paras):
+ losses_dict = {}
+        if self.criterions_auxi is None or len(self.criterions_auxi) == 0:
+ return losses_dict
+ args = dict(
+ target=paras['target'],
+ data_type=paras['data_type'],
+ sem_mask=paras['sem_mask'],
+ )
+ for i, auxi_logit in enumerate(paras['auxi_logit_list']):
+ auxi_losses_dict = self.branch_loss(
+ prediction=paras['auxi_pred'][i],
+ criterions=self.criterions_auxi,
+ pred_logit=auxi_logit,
+ branch=f'auxi_{i}',
+ **args
+ )
+ losses_dict.update(auxi_losses_dict)
+ return losses_dict
+
+ def compute_pose_loss(self, paras):
+ losses_dict = {}
+ if self.criterions_pose is None or len(self.criterions_pose) == 0:
+ return losses_dict
+ # valid_flg = paras['tmpl_flg']
+ # if torch.sum(valid_flg) == 0:
+ # return losses_dict
+ # else:
+ # # sample valid batch
+ # samples = {}
+ # for k, v in paras.items():
+ # if isinstance(v, torch.Tensor):
+ # samples.update({k: v[valid_flg]})
+ # elif isinstance(v, list) and isinstance(v[0], torch.Tensor):
+ # samples.update({k: [i[valid_flg] for i in v]})
+ for loss_method in self.criterions_pose:
+ loss_tmp = loss_method(**paras)
+ losses_dict['pose_' + loss_method._get_name()] = loss_tmp
+ return losses_dict
+
+ def branch_loss(self, prediction, pred_logit, criterions, branch='decode', **kwargs):
+ B, _, _, _ = prediction.shape
+ losses_dict = {}
+ args = dict(pred_logit=pred_logit)
+
+ target = kwargs.pop('target')
+ args.update(kwargs)
+
+ # data type for each batch
+ batches_data_type = np.array(kwargs['data_type'])
+ # batches_data_names = np.array(kwargs['dataset'])
+
+ # resize the target
+ # if target.shape[2] != prediction.shape[2] and target.shape[3] != prediction.shape[3]:
+ # _, _, H, W = prediction.shape
+ # target = nn.functional.interpolate(target, (H,W), mode='nearest')
+
+ mask = target > 1e-8
+ for loss_method in criterions:
+ # sample batches, which satisfy the loss requirement for data types
+ new_mask = self.create_mask_as_loss(loss_method, mask, batches_data_type)
+
+ loss_tmp = loss_method(
+ prediction=prediction,
+ target=target,
+ mask=new_mask,
+ **args)
+ losses_dict[branch + '_' + loss_method._get_name()] = loss_tmp
+ return losses_dict
+
+ def create_mask_as_loss(self, loss_method, mask, batches_data_type):
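+        # Roughly: loss_method.data_type (the types this loss accepts) is reshaped to (K, 1) and
+        # compared against the per-sample batch types; np.any over axis 0 gives a per-sample bool
+        # that is broadcast over the (B, 1, H, W) pixel validity mask.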
+ data_type_req = np.array(loss_method.data_type)[:, None]
+ batch_mask = torch.tensor(np.any(data_type_req == batches_data_type, axis=0), device="cuda") #torch.from_numpy(np.any(data_type_req == batches_data_type, axis=0)).cuda()
+ new_mask = mask * batch_mask[:, None, None, None]
+ return new_mask
+
+ def label_downsample(self, downsample_factor, data_dict):
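+        # Shrink the GT depth and stereo depth by 1/downsample_factor so they match a prediction
+        # produced at reduced resolution; pred_upsample() below restores the prediction and
+        # confidence to full resolution before they are returned.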
+ scale_factor = float(1.0 / downsample_factor)
+ downsample_target = F.interpolate(data_dict['target'], scale_factor=scale_factor)
+ downsample_stereo_depth = F.interpolate(data_dict['stereo_depth'], scale_factor=scale_factor)
+
+ data_dict['target'] = downsample_target
+ data_dict['stereo_depth'] = downsample_stereo_depth
+
+ return data_dict
+
+ def pred_upsample(self, downsample_factor, data_dict):
+ scale_factor = float(downsample_factor)
+ upsample_prediction = F.interpolate(data_dict['prediction'], scale_factor=scale_factor).detach()
+ upsample_confidence = F.interpolate(data_dict['confidence'], scale_factor=scale_factor).detach()
+
+ data_dict['prediction'] = upsample_prediction
+ data_dict['confidence'] = upsample_confidence
+
+ return data_dict
+
+
+
+
+ # def mask_batches(self, prediction, target, mask, batches_data_names, data_type_req):
+ # """
+ # Mask the data samples that satify the loss requirement.
+ # Args:
+ # data_type_req (str): the data type required by a loss.
+ # batches_data_names (list): the list of data types in a batch.
+ # """
+ # batch_mask = np.any(data_type_req == batches_data_names, axis=0)
+ # prediction = prediction[batch_mask]
+ # target = target[batch_mask]
+ # mask = mask[batch_mask]
+ # return prediction, target, mask, batch_mask
+
+ # def update_mask_g8(self, target, mask, prediction, batches_data_names, absRel=0.5):
+ # data_type_req=np.array(['Golf8_others'])[:, None]
+
+ # pred, target, mask_sample, batch_mask = self.mask_batches(prediction, target, mask, batches_data_names, data_type_req)
+ # if pred.numel() == 0:
+ # return mask
+ # scale_batch = []
+ # for i in range(mask_sample.shape[0]):
+ # scale = torch.median(target[mask_sample]) / (torch.median(pred[mask_sample]) + 1e-8)
+ # abs_rel = torch.abs(pred[i:i+1, ...] * scale - target[i:i+1, ...]) / (pred[i:i+1, ...] * scale + 1e-6)
+ # if target[i, ...][target[i, ...]>0].min() < 0.041:
+ # mask_valid_i = ((abs_rel < absRel) | ((target[i:i+1, ...]<0.02) & (target[i:i+1, ...]>1e-6))) & mask_sample[i:i+1, ...]
+ # else:
+ # mask_valid_i = mask_sample[i:i+1, ...]
+ # mask_sample[i:i+1, ...] = mask_valid_i
+ # # print(target.max(), target[target>0].min())
+ # # self.visual_g8(target, mask_valid_i)
+ # mask[batch_mask] = mask_sample
+ # return mask
+
+ # def update_mask_g8_v2(self, target, mask, prediction, batches_data_names,):
+ # data_type_req=np.array(['Golf8_others'])[:, None]
+
+ # pred, target, mask_sample, batch_mask = self.mask_batches(prediction, target, mask, batches_data_names, data_type_req)
+ # if pred.numel() == 0:
+ # return mask
+
+ # raw_invalid_mask = target < 1e-8
+ # target[raw_invalid_mask] = 1e8
+ # kernal = 31
+ # pool = min_pool2d(target, kernal)
+ # diff = target- pool
+ # valid_mask = (diff < 0.02) & mask_sample & (target<0.3)
+ # target_min = target.view(target.shape[0], -1).min(dim=1)[0]
+ # w_close = target_min < 0.04
+ # valid_mask[~w_close] = mask_sample[~w_close]
+ # mask[batch_mask]= valid_mask
+
+ # target[raw_invalid_mask] = -1
+ # #self.visual_g8(target, mask[batch_mask])
+ # return mask
+
+ # def visual_g8(self, gt, mask):
+ # import matplotlib.pyplot as plt
+ # from mono.utils.transform import gray_to_colormap
+ # gt = gt.cpu().numpy().squeeze()
+ # mask = mask.cpu().numpy().squeeze()
+ # if gt.ndim >2:
+ # gt = gt[0, ...]
+ # mask = mask[0, ...]
+ # name = np.random.randint(1000000)
+ # print(gt.max(), gt[gt>0].min(), name)
+ # gt_filter = gt.copy()
+ # gt_filter[~mask] = 0
+ # out = np.concatenate([gt, gt_filter], axis=0)
+ # out[out<0] = 0
+ # o = gray_to_colormap(out)
+ # o[out<1e-8]=0
+
+ # plt.imsave(f'./tmp/{name}.png', o)
+
+
+
+
+
+def min_pool2d(tensor, kernel, stride=1):
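+    # Min-pooling implemented as max-pooling of the negated tensor (PyTorch has no min_pool2d);
+    # referenced by the commented-out update_mask_g8_v2 helper above.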
+ tensor = tensor * -1.0
+ tensor = F.max_pool2d(tensor, kernel, padding=kernel//2, stride=stride)
+ tensor = -1.0 * tensor
+ return tensor
\ No newline at end of file
diff --git a/training/mono/model/__init__.py b/training/mono/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca10dcd6c22af3f61832b621c0b05663d629e0b8
--- /dev/null
+++ b/training/mono/model/__init__.py
@@ -0,0 +1,6 @@
+from .monodepth_model import DepthModel
+from .criterion import build_criterions
+from .__base_model__ import BaseDepthModel
+
+
+__all__ = ['DepthModel', 'BaseDepthModel']
diff --git a/training/mono/model/backbones/ConvNeXt.py b/training/mono/model/backbones/ConvNeXt.py
new file mode 100644
index 0000000000000000000000000000000000000000..04d92cdad9c8cbbe9fd448c6c72ecf12e5ec7614
--- /dev/null
+++ b/training/mono/model/backbones/ConvNeXt.py
@@ -0,0 +1,271 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_, DropPath
+from timm.models.registry import register_model
+
+class Block(nn.Module):
+ r""" ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ """
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(4 * dim, dim)
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
+ requires_grad=True) if layer_scale_init_value > 0 else None
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ if self.gamma is not None:
+ x = self.gamma * x
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+class ConvNeXt(nn.Module):
+ r""" ConvNeXt
+ A PyTorch impl of : `A ConvNet for the 2020s` -
+ https://arxiv.org/pdf/2201.03545.pdf
+ Args:
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
+ """
+ def __init__(self, in_chans=3, num_classes=1000,
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
+ layer_scale_init_value=1e-6, head_init_scale=1.,
+ **kwargs,):
+ super().__init__()
+
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
+ stem = nn.Sequential(
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
+ )
+ self.downsample_layers.append(stem)
+ for i in range(3):
+ downsample_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
+ )
+ self.downsample_layers.append(downsample_layer)
+
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+ cur = 0
+ for i in range(4):
+ stage = nn.Sequential(
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ #self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
+ #self.head = nn.Linear(dims[-1], num_classes)
+
+ self.apply(self._init_weights)
+ #self.head.weight.data.mul_(head_init_scale)
+ #self.head.bias.data.mul_(head_init_scale)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ trunc_normal_(m.weight, std=.02)
+ nn.init.constant_(m.bias, 0)
+
+ def forward_features(self, x):
+ features = []
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ x = self.stages[i](x)
+ features.append(x)
+        return features  # list of four feature maps, one per stage (strides 4, 8, 16, 32)
+
+ def forward(self, x):
+ #x = self.forward_features(x)
+ #x = self.head(x)
+ features = self.forward_features(x)
+ return features
+
+class LayerNorm(nn.Module):
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape, )
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+ elif self.data_format == "channels_first":
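+            # Normalize over the channel dimension of an NCHW tensor by hand, since F.layer_norm
+            # only normalizes over trailing dimensions.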
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+model_urls = {
+ "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
+ "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
+ "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
+ "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
+ "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
+ "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
+ "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
+ "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
+ "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
+}
+
+def convnext_tiny(pretrained=True,in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ #url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d paras, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are:', unmatched_pretrained_dict.keys())
+ return model
+
+def convnext_small(pretrained=True,in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ #url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d paras, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are:', unmatched_pretrained_dict.keys())
+ return model
+
+def convnext_base(pretrained=True, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ #url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d paras, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are:', unmatched_pretrained_dict.keys())
+ return model
+
+def convnext_large(pretrained=True, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
+ if pretrained:
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ #url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d paras, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are:', unmatched_pretrained_dict.keys())
+ return model
+
+def convnext_xlarge(pretrained=True, in_22k=False, **kwargs):
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
+ if pretrained:
+ assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
+ checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu")
+ #url = model_urls['convnext_xlarge_22k']
+ #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
+ model_dict = model.state_dict()
+ pretrained_dict = {}
+ unmatched_pretrained_dict = {}
+ for k, v in checkpoint['model'].items():
+ if k in model_dict:
+ pretrained_dict[k] = v
+ else:
+ unmatched_pretrained_dict[k] = v
+ model_dict.update(pretrained_dict)
+ model.load_state_dict(model_dict)
+ print(
+ 'Successfully loaded pretrained %d paras, and %d paras are unmatched.'
+ %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys())))
+ print('Unmatched pretrained paras are:', unmatched_pretrained_dict.keys())
+ return model
+
+if __name__ == '__main__':
+ import torch
+    model = convnext_base(pretrained=False).cuda()  # no checkpoint path is provided here, so skip pretrained loading
+
+ rgb = torch.rand((2, 3, 256, 256)).cuda()
+ out = model(rgb)
+ print(len(out))
+ for i, ft in enumerate(out):
+ print(i, ft.shape)
diff --git a/training/mono/model/backbones/ViT_DINO.py b/training/mono/model/backbones/ViT_DINO.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a1998f0dd5024fbe69895e244fc054245a06568
--- /dev/null
+++ b/training/mono/model/backbones/ViT_DINO.py
@@ -0,0 +1,1504 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F  # used by SwiGLUFFN below
+from torch import Tensor
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+logger = logging.getLogger("dinov2")
+
+class ConvBlock(nn.Module):
+ def __init__(self, channels):
+ super(ConvBlock, self).__init__()
+
+ self.act = nn.ReLU(inplace=True)
+ self.conv1 = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ )
+ self.norm1 = nn.BatchNorm2d(channels)
+ self.conv2 = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ )
+ self.norm2 = nn.BatchNorm2d(channels)
+
+ def forward(self, x):
+
+ out = self.norm1(x)
+ out = self.act(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = self.act(out)
+ out = self.conv2(out)
+ return x + out
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
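+        # SwiGLU gating: SiLU(x1) modulates x2 before the output projection w3.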
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+ from xformers.components.attention import ScaledDotProduct
+ from xformers.components import MultiHeadDispatch
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ window_size: int = 0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ #if not self.training:
+ #
+ # self.attn = ScaledDotProduct()
+ #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ if attn_bias is not None:
+ attn = attn + attn_bias[:, :, :N]
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ #if True:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x, attn_bias)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+ if attn_bias is not None:
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N])
+ else:
+ x = memory_efficient_attention(q, k, v)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values = None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ def attn_residual_func(x: Tensor, attn_bias) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias))
+
+        def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:  # accepts (and ignores) attn_bias, which drop_add_residual_stochastic_depth passes
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ attn_bias=attn_bias
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, attn_bias))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x, attn_bias)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0, attn_bias=None
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset, attn_bias)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list, attn_bias=None):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list, attn_bias)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x, others=None):
+ for b in self:
+            if others is None:
+ x = b(x)
+ else:
+ x = b(x, others)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ #init_values=None, # for layerscale: None or 0 => no layerscale
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=NestedTensorBlock,
+ ffn_layer="mlp",
+ block_chunks=1,
+ window_size=37,
+ **kwargs
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.window_size = window_size
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + 0.1, h0 + 0.1
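+        # e.g. with a 37x37 patch grid (N = 1369), the scale factor becomes 37.1/37,
+        # so floor(37 * scale_factor) is exactly 37; with an exact factor of 1.0,
+        # floating point error could round the interpolated size down to 36.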
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+ mode="bicubic",
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ return x
+
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ B, C, H, W = x.size()
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
+ if pad_h + pad_w > 0:
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ features = []
+ for blk in self.blocks:
+ x = blk(x)
+ # for idx in range(len(self.blocks[0])):
+ # x = self.blocks[0][idx](x)
+ # if (idx + 1) % (len(self.blocks[0]) // 4) == 0:
+ # features.append(x)
+
+ #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ x_norm = self.norm(x)
+ # return {
+ # "x_norm_clstoken": x_norm[:, 0],
+ # "x_norm_patchtokens": x_norm[:, 1:],
+ # "x_prenorm": x,
+ # "masks": masks,
+ # }
+ features = []
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ return ret
+ # if is_training:
+ # return ret
+ # else:
+ # return self.head(ret["x_norm_clstoken"])
+
+
+class PosConv(nn.Module):
+ # PEG from https://arxiv.org/abs/2102.10882
+ def __init__(self, in_chans, embed_dim=768, stride=1):
+ super(PosConv, self).__init__()
+ self.proj = nn.Sequential(
+ nn.Conv2d(in_chans, embed_dim, 37, stride, 18, bias=True, groups=embed_dim),
+ )
+ self.stride = stride
+
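+    # forward() reshapes the (B, N, C) token sequence back to a (B, C, H, W) grid,
+    # runs the depthwise 37x37 convolution (padding 18 preserves the spatial size),
+    # and, when stride == 1, adds the result back to the input tokens, so the
+    # positional signal is generated conditionally from the features themselves.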
+ def forward(self, x, size):
+ B, N, C = x.shape
+ cnn_feat_token = x.transpose(1, 2).view(B, C, *size)
+ x = self.proj(cnn_feat_token)
+ if self.stride == 1:
+ x += cnn_feat_token
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+ #def no_weight_decay(self):
+ #return ['proj.%d.weight' % i for i in range(4)]
+
+class DinoWindowVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ #init_values=None, # for layerscale: None or 0 => no layerscale
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=NestedTensorBlock,
+ ffn_layer="mlp",
+ block_chunks=1,
+ window_size=7,
+ **kwargs
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+            window_size: (int) local attention window size; values <= 0 disable the windowed attention bias
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ #self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ #self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+
+ self.pos_conv = PosConv(self.embed_dim, self.embed_dim)
+
+ self.window_size = window_size
+ #self.conv_block = nn.ModuleList([ConvBlock(embed_dim) for i in range(4)])
+ #self.conv_block = nn.ModuleList([nn.Identity() for i in range(4)])
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.nh = -1
+ self.nw = -1
+ try:
+ H = cfg.data_basic['crop_size'][0]
+ W = cfg.data_basic['crop_size'][1]
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
+ self.nh = (H + pad_h) // self.patch_size
+ self.nw = (W + pad_w) // self.patch_size
+ self.prepare_attn_bias((self.nh, self.nw))
+ except:
+ pass
+ self.init_weights()
+
+ self.total_step = 10000 # For PE -> GPE transfer
+ self.start_step = 2000
+ self.current_step = 20000
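+        # Schedule used in prepare_tokens_with_masks(): before start_step only the
+        # interpolated absolute pos_embed is used (coef = 0); between start_step and
+        # total_step the weight shifts linearly towards the PosConv branch; after
+        # total_step coef = 1. current_step defaults past total_step, so inference
+        # relies on the convolutional positional encoding only.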
+
+ def init_weights(self):
+ #trunc_normal_(self.pos_embed, std=0.02)
+ #nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+ for i in range(4):
+ try:
+ nn.init.constant_(self.conv_block[i].conv2.weight, 0.0)
+ except:
+ pass
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ #npatch = x.shape[1] - 1
+ #N = self.pos_embed.shape[1] - 1
+ npatch = x.shape[1]
+ N = self.pos_embed.shape[1]
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ #class_pos_embed = pos_embed[:, 0]
+ #patch_pos_embed = pos_embed[:, 1:]
+ patch_pos_embed = pos_embed
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + 0.1, h0 + 0.1
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+ mode="bicubic",
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return patch_pos_embed.to(previous_dtype)
+ #return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def window_partition(self, x: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ """
+        Partition into non-overlapping windows (H and W are assumed to be divisible by window_size).
+        Args:
+            x (tensor): input tokens with [B, H * W, C], or a [B, C, H, W] feature map when conv_feature is True.
+            window_size (int): window size.
+            hw (Tuple): height and width (H, W) of the token grid.
+
+        Returns:
+            windows: windows after partition with [B * num_windows, window_size * window_size, C].
+ """
+ if conv_feature == False:
+ B, N, C = x.shape
+ H, W = hw[0], hw[1]
+
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C)
+ else:
+ B, C, H, W = x.shape
+
+ x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
+
+ windows = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, window_size * window_size, C)
+
+ #y = torch.cat((x_cls, windows), dim=1)
+ return windows #, (Hp, Wp)
+
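+    # Illustrative shapes (assuming H and W are divisible by window_size): for
+    # x of shape (2, 14*14, 768) with hw = (14, 14) and window_size = 7,
+    # window_partition returns (2 * 4, 49, 768); window_unpartition below is the
+    # inverse and recovers a (2, 14*14, 768) sequence (or a (2, 768, 14, 14)
+    # feature map when conv_feature is True).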
+
+ def window_unpartition(self,
+ windows: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False
+ ) -> torch.Tensor:
+ """
+        Window unpartition back into the original token sequence.
+        Args:
+            windows (tensor): input tokens with [B * num_windows, window_size * window_size, C].
+            window_size (int): window size.
+            hw (Tuple): height and width (H, W) of the token grid.
+
+        Returns:
+            x: unpartitioned tokens with [B, H * W, C], or [B, C, H, W] when conv_feature is True.
+ """
+ H, W = hw
+
+ B = windows.shape[0] // (H * W // window_size // window_size)
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+
+ if conv_feature == False:
+            x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H * W, -1)
+ else:
+ C = windows.shape[-1]
+ x = x.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, C, H, W)
+
+ # if Hp > H or Wp > W:
+ # x = x[:, :H, :W, :].contiguous()
+ return x
+
+ def prepare_tokens_with_masks(self, x, masks=None, step=-1):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ #x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ if step == -1:
+ step = self.current_step
+ else:
+ self.current_step = step
+
+ if step < self.start_step:
+ coef = 0.0
+ elif step < self.total_step:
+ coef = (step - self.start_step) / (self.total_step - self.start_step)
+ else:
+ coef = 1.0
+
+ x = x + (1 - coef) * self.interpolate_pos_encoding(x, w, h) + coef * self.pos_conv(x, (self.nh, self.nw))
+
+ return x
+
+ def prepare_attn_bias(self, shape):
+ window_size = self.window_size
+ if window_size <= 0:
+ return
+
+ import xformers.components.attention.attention_patterns as AP
+
+ nh, nw = shape
+ radius = (window_size-1)//2
+ mask_ori = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
+
+ pad = (8 - (nh * nw) % 8)
+ if pad == 8:
+ pad = 0
+ mask_pad = nn.functional.pad(mask_ori, (0, pad)).contiguous()
+ if pad > 0:
+ mask = mask_pad[:, :-pad].view(nh, nw, nh, nw)
+ else:
+ mask = mask_pad[:, :].view(nh, nw, nh, nw)
+
+        # corners: let border windows attend to a full window
+ mask[:radius+1, :radius+1, :window_size, :window_size] = True
+ mask[:radius+1, -radius-1:, :window_size, -window_size:] = True
+ mask[-radius-1:, :radius+1, -window_size:, :window_size] = True
+ mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True
+
+ # edge
+ mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :]
+ mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :]
+ mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :]
+ mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :]
+
+ mask = mask.view(nh*nw, nh*nw)
+ bias_pad = torch.log(mask_pad)
+ #bias = bias_pad[:, :-pad]
+ self.register_buffer('attn_bias', bias_pad)
+
+ return bias_pad
+
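+    # The boolean local-window pattern above is turned into an additive attention
+    # bias via torch.log (0.0 where attention is allowed, -inf elsewhere). The last
+    # dimension is padded to a multiple of 8, presumably to satisfy the alignment
+    # that xFormers' memory_efficient_attention expects for attn_bias, and the
+    # corner/edge fix-ups let border tokens attend to a full window rather than a
+    # clipped one.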
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_patchtokens": x_norm[:, 1:],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None, **kwargs):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ B, C, H, W = x.size()
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
+ if pad_h + pad_w > 0:
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
+
+ nh = (H+pad_h)//self.patch_size
+ nw = (W+pad_w)//self.patch_size
+
+ if self.window_size > 0:
+ if nh == self.nh and nw == self.nw:
+ attn_bias = self.attn_bias
+ else:
+ attn_bias = self.prepare_attn_bias(((H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size))
+ self.nh = nh
+ self.nw = nw
+ attn_bias = attn_bias.unsqueeze(0).repeat(B * self.num_heads, 1, 1)
+ else:
+ attn_bias = None
+
+ x = self.prepare_tokens_with_masks(x, masks)
+ #x = self.patch_embed(x)
+
+ features = []
+ #x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size))
+ for blk in self.blocks:
+ x = blk(x, attn_bias)
+ #x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size))
+
+ # for idx in range(len(self.blocks[0])):
+ # x = self.blocks[0][idx](x, attn_bias)
+
+ # if (idx + 1) % (len(self.blocks[0]) // 4) == 0:
+ # x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True)
+ # x = self.conv_block[idx // (len(self.blocks[0]) // 4)](x)
+ # if idx + 1 != len(self.blocks[0]):
+ # x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True)
+ # else:
+ # b, c, h, w = x.size()
+ # x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, c)
+ #features.append(x)
+
+ #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ x_norm = self.norm(x)
+ # return {
+ # "x_norm_clstoken": x_norm[:, 0],
+ # "x_norm_patchtokens": x_norm[:, 1:],
+ # "x_prenorm": x,
+ # "masks": masks,
+ # }
+ features = []
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)]
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ return ret
+ # if is_training:
+ # return ret
+ # else:
+ # return self.head(ret["x_norm_clstoken"])
+
+
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=14, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=14, **kwargs):
+ model = DinoWindowVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=14, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ img_size = 518,
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ **kwargs,
+ )
+
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ try:
+ model.load_state_dict(state_dict, strict=True)
+ except:
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if 'blocks' in key:
+ key_new = 'blocks.0' + key[len('blocks'):]
+ else:
+ key_new = key
+ new_state_dict[key_new] = value
+
+ model.load_state_dict(new_state_dict, strict=True)
+ #del model.norm
+ del model.mask_token
+ return model
+
+ # model = DinoWindowVisionTransformer(
+ # img_size = 518,
+ # patch_size=patch_size,
+ # embed_dim=1024,
+ # depth=24,
+ # num_heads=16,
+ # mlp_ratio=4,
+ # block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention),
+ # window_size=37,
+ # **kwargs,
+ # )
+
+ # if checkpoint is not None:
+ # with open(checkpoint, "rb") as f:
+ # state_dict = torch.load(f)
+ # try:
+ # model.load_state_dict(state_dict, strict=True)
+ # except:
+ # new_state_dict = {}
+ # for key, value in state_dict.items():
+ # if 'blocks' in key:
+ # key_new = 'blocks.0' + key[len('blocks'):]
+ # else:
+ # key_new = key
+ # if 'pos_embed' in key:
+ # value = value[:, 1:, :]
+ # new_state_dict[key_new] = value
+
+ # model.load_state_dict(new_state_dict, strict=False)
+ # #del model.norm
+ # del model.mask_token
+
+
+def vit_giant2(patch_size=16, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+if __name__ == '__main__':
+ try:
+ from mmcv.utils import Config
+    except ImportError:
+ from mmengine import Config
+
+ #rgb = torch.rand((2, 3, 518, 518)).cuda()
+
+ #cfg.data_basic['crop_size']['0']
+ #cfg.data_basic['crop_size']['1']
+ cfg = Config.fromfile('/cpfs01/user/mu.hu/monodepth/mono/configs/HourglassDecoder/pub12.convlarge.0.3_150.py')
+
+ #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036)
+ rgb = torch.zeros(1, 3, 1400, 1680).cuda()
+ model = vit_large(checkpoint="/cpfs02/shared/public/custom/group_local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth", kwarg=cfg).cuda()
+
+ #import timm
+ #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda()
+ #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn)
+
+ out1 = model(rgb)
+ #out2 = model2(rgb)
+ temp = 0
+
+
+
+# import time
+# window_size = 37
+# def prepare_window_masks(shape):
+# if window_size <= 0:
+# return None
+# import xformers.components.attention.attention_patterns as AP
+
+# B, nh, nw, _, _ = shape
+# radius = (window_size-1)//2
+# #time0 = time.time()
+# d = AP.local_nd_distance(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
+# #mask = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda()
+# # mask = mask.view(nh, nw, nh, nw)
+# # #time1 = time.time() - time0
+
+# # # angle
+# # mask[:radius+1, :radius+1, :window_size, :window_size] = True
+# # mask[:radius+1, -radius-1:, :window_size, -window_size:] = True
+# # mask[-radius-1:, :radius+1, -window_size:, :window_size] = True
+# # mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True
+# # time2 = time.time() - time0 - time1
+
+# # # edge
+# # mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :]
+# # mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :]
+# # mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :]
+# # mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :]
+# # time3 = time.time() - time0 - time2
+# # print(time1, time2, time3)
+
+# # return mask.view(nw*nw, nh*nw).unsqueeze(0).repeat(B, 1)
+
+# shape = (1, 55, 55, None, None)
+# mask = prepare_window_masks(shape)
+# # temp = 1
\ No newline at end of file
diff --git a/training/mono/model/backbones/ViT_DINO_reg.py b/training/mono/model/backbones/ViT_DINO_reg.py
new file mode 100644
index 0000000000000000000000000000000000000000..89bcbdc58111a57f2f0a44a1560ad2f99534764b
--- /dev/null
+++ b/training/mono/model/backbones/ViT_DINO_reg.py
@@ -0,0 +1,1099 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+logger = logging.getLogger("dinov2")
+
+class ConvBlock(nn.Module):
+ def __init__(self, channels):
+ super(ConvBlock, self).__init__()
+
+ self.act = nn.ReLU(inplace=True)
+ self.conv1 = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ )
+ self.norm1 = nn.BatchNorm2d(channels)
+ self.conv2 = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ )
+ self.norm2 = nn.BatchNorm2d(channels)
+
+ def forward(self, x):
+
+ out = self.norm1(x)
+ out = self.act(out)
+ out = self.conv1(out)
+ out = self.norm2(out)
+ out = self.act(out)
+ out = self.conv2(out)
+ return x + out
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
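+# A minimal illustration: with drop_prob = 0.1 and an input of shape (4, N, C),
+# each sample's residual branch is zeroed independently with probability 0.1 and
+# the survivors are scaled by 1 / 0.9, so the expected output is unchanged.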
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
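+    # Example shapes: a (1, 3, 224, 224) image with patch_size 16 gives a 14x14
+    # patch grid, i.e. an output of shape (1, 196, embed_dim) when
+    # flatten_embedding is True.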
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
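+        # Scale the hidden width by 2/3 so the gated (two-projection) SwiGLU has
+        # roughly the parameter count of a plain MLP with the same mlp_ratio, then
+        # round up to a multiple of 8 for hardware-friendly matrix sizes.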
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
+
+
+try:
+ from xformers.ops import memory_efficient_attention, unbind, fmha
+ from xformers.components.attention import ScaledDotProduct
+ from xformers.components import MultiHeadDispatch
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ window_size: int = 0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ #if not self.training:
+ #
+ # self.attn = ScaledDotProduct()
+ #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn)
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ if attn_bias is not None:
+ attn = attn + attn_bias[:, :, :N]
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE:
+ #if True:
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x, attn_bias)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+ if attn_bias is not None:
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N])
+ else:
+ x = memory_efficient_attention(q, k, v)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+try:
+ from xformers.ops import fmha
+ from xformers.ops import scaled_index_add, index_select_cat
+ #import numpy.bool
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values = None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ def attn_residual_func(x: Tensor, attn_bias) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias))
+
+ def ffn_residual_func(x: Tensor) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ attn_bias=attn_bias
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x, attn_bias))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x, attn_bias)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: Tensor,
+ residual_func: Callable[[Tensor], Tensor],
+ sample_drop_ratio: float = 0.0, attn_bias=None
+) -> Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset, attn_bias)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ return x_plus_residual.view_as(x)
+
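+# Example: with b = 8 and sample_drop_ratio = 0.25, the residual branch runs on a
+# random subset of 6 samples and is added back with scale 8/6, keeping the
+# expected update equal to the full-batch computation.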
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+ else:
+ x_plus_residual = scaled_index_add(
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
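+    # Illustrative example: for x_list with shapes (2, 100, C) and (3, 60, C), the
+    # cached BlockDiagonalMask is built from seqlens [100, 100, 60, 60, 60] and the
+    # concatenated tensor has shape (1, 380, C); the block-diagonal structure
+    # prevents attention across different samples.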
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[Tensor],
+ residual_func: Callable[[Tensor, Any], Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list, attn_bias=None):
+ if isinstance(x_or_x_list, Tensor):
+ return super().forward(x_or_x_list, attn_bias)
+ elif isinstance(x_or_x_list, list):
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x, others=None):
+ for b in self:
+            if others is None:
+ x = b(x)
+ else:
+ x = b(x, others)
+ return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=518,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=1e-5, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.1,
+ multi_output=False,
+ **kwargs
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ weight_init (str): weight init scheme
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+            interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
+            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+            multi_output: (bool) if True, forward_features also returns intermediate features collected at quarter-depth intervals
+ """
+ super().__init__()
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 1
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ self.multi_output = multi_output
+ assert num_register_tokens >= 0
+ self.register_tokens = (
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ logger.info("using MLP layer as FFN")
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ logger.info("using SwiGLU layer as FFN")
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+ logger.info("using Identity layer as FFN")
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ if block_chunks > 0:
+ self.chunked_blocks = True
+ chunked_blocks = []
+ chunksize = depth // block_chunks
+ for i in range(0, depth, chunksize):
+ # this is to keep the block index consistent if we chunk the block list
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+ else:
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Identity()
+
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.register_tokens is not None:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+ # we add a small number to avoid floating point error in the interpolation
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+
+ sqrt_N = math.sqrt(N)
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+ scale_factor=(sx, sy),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ )
+
+ assert int(w0) == patch_pos_embed.shape[-2]
+ assert int(h0) == patch_pos_embed.shape[-1]
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ x = self.patch_embed(x)
+ if masks is not None:
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.register_tokens is not None:
+ x = torch.cat(
+ (
+ x[:, :1],
+ self.register_tokens.expand(x.shape[0], -1, -1),
+ x[:, 1:],
+ ),
+ dim=1,
+ )
+
+ return x
+
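+    # Resulting token layout: [CLS, reg_1 .. reg_R, patch_1 .. patch_N]. The
+    # register tokens are inserted after the positional encoding is added, so they
+    # carry no positional embedding of their own.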
+ def forward_features_list(self, x_list, masks_list):
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+ for blk in self.blocks:
+ x = blk(x)
+
+ all_x = x
+ output = []
+ for x, masks in zip(all_x, masks_list):
+ x_norm = self.norm(x)
+ output.append(
+ {
+ "x_norm_clstoken": x_norm[:, 0],
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ "x_prenorm": x,
+ "masks": masks,
+ }
+ )
+ return output
+
+ def forward_features(self, x, masks=None):
+ if isinstance(x, list):
+ return self.forward_features_list(x, masks)
+
+ B, C, H, W = x.size()
+ pad_h = (self.patch_size - H % self.patch_size)
+ pad_w = (self.patch_size - W % self.patch_size)
+ if pad_h == self.patch_size:
+ pad_h = 0
+ if pad_w == self.patch_size:
+ pad_w = 0
+ #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2))
+ if pad_h + pad_w > 0:
+ x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear')
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ # return {
+ # "x_norm_clstoken": x_norm[:, 0],
+ # "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+ # "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+ # "x_prenorm": x,
+ # "masks": masks,
+ # }
+ if self.multi_output == False:
+ for blk in self.blocks:
+ x = blk(x)
+ x_norm = self.norm(x)
+ features = []
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ features.append(x_norm)
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)]
+ else:
+ features = []
+ for blk in self.blocks:
+ for idx, sub_blk in enumerate(blk):
+ x = sub_blk(x)
+ if (idx + 1) % (len(blk) // 4) == 0:
+ features.append(x)
+
+ return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)]
+
+
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ # If n is an int, take the n last blocks. If it's a list, take them
+ output, total_block_len = [], len(self.blocks)
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def _get_intermediate_layers_chunked(self, x, n=1):
+ x = self.prepare_tokens_with_masks(x)
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
+ # If n is an int, take the n last blocks. If it's a list, take them
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+ for block_chunk in self.blocks:
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
+ x = blk(x)
+ if i in blocks_to_take:
+ output.append(x)
+ i += 1
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+ return output
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
+ reshape: bool = False,
+ return_class_token: bool = False,
+ norm=True,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ if self.chunked_blocks:
+ outputs = self._get_intermediate_layers_chunked(x, n)
+ else:
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, 0] for out in outputs]
+ outputs = [out[:, 1:] for out in outputs]
+ if reshape:
+ B, _, w, h = x.shape
+ outputs = [
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+ for out in outputs
+ ]
+ if return_class_token:
+ return tuple(zip(outputs, class_tokens))
+ return tuple(outputs)
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ return ret
+ # if is_training:
+ # return ret
+ # else:
+ # return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
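+# Official DINOv2 checkpoints name transformer weights "blocks.<i>....", while a
+# model built with block_chunks > 0 wraps them in a BlockChunk and therefore uses
+# "blocks.0.<i>....". The fallback below remaps the keys before retrying a strict
+# load.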
+def load_ckpt_dino(checkpoint, model, reserve_norm=True):
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ try:
+ model.load_state_dict(state_dict, strict=True)
+ except:
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if 'blocks' in key:
+ key_new = 'blocks.0' + key[len('blocks'):]
+ else:
+ key_new = key
+ new_state_dict[key_new] = value
+
+ model.load_state_dict(new_state_dict, strict=True)
+ del model.mask_token
+ if reserve_norm == False:
+ del model.norm
+ return
+ else:
+ return
+
+
+def vit_small(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_base(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f)
+ try:
+ model.load_state_dict(state_dict, strict=True)
+ except:
+ new_state_dict = {}
+ for key, value in state_dict.items():
+ if 'blocks' in key:
+ key_new = 'blocks.0' + key[len('blocks'):]
+ else:
+ key_new = key
+ new_state_dict[key_new] = value
+
+ model.load_state_dict(new_state_dict, strict=True)
+ del model.mask_token
+ return model
+
+
+def vit_giant2(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ ffn_layer='swiglu',
+ **kwargs,
+ )
+ return model
+
+
+
+def vit_small_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_base_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_large_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs):
+ model = DinoVisionTransformer(
+ img_size = 518,
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model)
+
+ return model
+
+
+def vit_giant2_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ num_register_tokens=num_register_tokens,
+ ffn_layer='swiglu',
+ multi_output=True,
+ **kwargs,
+ )
+
+ load_ckpt_dino(checkpoint, model, reserve_norm=False)
+
+ return model
+
+if __name__ == '__main__':
+    try:
+        from mmcv.utils import Config
+    except ImportError:
+        from mmengine import Config
+
+ #rgb = torch.rand((2, 3, 518, 518)).cuda()
+
+ #cfg.data_basic['crop_size']['0']
+ #cfg.data_basic['crop_size']['1']
+ cfg = Config.fromfile('/cpfs01/shared/public/users/mu.hu/monodepth/mono/configs/RAFTDecoder/vit.raft.full2t.py')
+
+ #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036)
+ rgb = torch.zeros(1, 3, 616, 1064).cuda()
+ #model = vit_large_reg(checkpoint="/cpfs02/shared/public/groups/local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_reg4_pretrain.pth", kwarg=cfg).cuda()
+ model = vit_giant2_reg(checkpoint="/cpfs02/shared/public/groups/local_map/yvan/pretrained_weight_repo/vit/dinov2_vitg14_reg4_pretrain.pth", kwarg=cfg).cuda()
+
+ #import timm
+ #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda()
+ #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn)
+
+ out1 = model(rgb)
+ #out2 = model2(rgb)
+ temp = 0
+
+
diff --git a/training/mono/model/backbones/__init__.py b/training/mono/model/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..51577dcd12c51c16191080c0c5954e0bcd3896c4
--- /dev/null
+++ b/training/mono/model/backbones/__init__.py
@@ -0,0 +1,8 @@
+from .ViT_DINO import vit_large
+from .ViT_DINO_reg import vit_small_reg, vit_large_reg, vit_giant2_reg
+
+__all__ = [
+    "vit_large",
+    "vit_small_reg",
+    "vit_large_reg",
+    "vit_giant2_reg",
+]
diff --git a/training/mono/model/criterion.py b/training/mono/model/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..4185dcc1912b249bd2bd3a77ff078103867e8501
--- /dev/null
+++ b/training/mono/model/criterion.py
@@ -0,0 +1,62 @@
+from .losses import *
+from mono.utils.comm import get_func
+import os
+
+def build_from_cfg(cfg, default_args=None):
+ """Build a module from config dict.
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ default_args (dict, optional): Default initialization arguments.
+ Returns:
+ object: The constructed object.
+ """
+ if not isinstance(cfg, dict):
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
+ if 'type' not in cfg:
+        raise RuntimeError('Loss config must contain the key "type" with the loss name')
+    args = cfg.copy()
+
+    obj_name = args.pop('type')
+    obj_path = os.path.dirname(__file__).split(os.getcwd() + '/')[-1].replace('/', '.') + '.losses.' + obj_name
+
+    # Resolve the loss class first, then instantiate it with the remaining config entries.
+    obj_cls = get_func(obj_path)
+    if obj_cls is None:
+        raise KeyError(f'cannot find {obj_name}.')
+    return obj_cls(**args)
+
+
+
+
+def build_criterions(cfg):
+ if 'losses' not in cfg:
+ raise RuntimeError('Losses have not been configured.')
+ cfg_data_basic = cfg.data_basic
+
+ criterions = dict()
+ losses = cfg.losses
+ if not isinstance(losses, dict):
+        raise RuntimeError(f'Cannot initialize losses with a config of type {type(losses)}')
+ for key, loss_list in losses.items():
+ criterions[key] = []
+ for loss_cfg_i in loss_list:
+ # update the canonical_space configs to the current loss cfg
+ loss_cfg_i.update(cfg_data_basic)
+ if 'out_channel' in loss_cfg_i:
+                loss_cfg_i.update(out_channel=cfg.out_channel)  # classification losses need the output channel count
+ obj_cls = build_from_cfg(loss_cfg_i)
+ criterions[key].append(obj_cls)
+ return criterions
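+
+# Illustrative sketch (hypothetical config values, not taken from any real training config):
+# with a losses section such as
+#     losses = dict(
+#         decoder_losses=[
+#             dict(type='GRUSequenceLoss', loss_weight=1.0, loss_gamma=0.9),
+#             dict(type='GradientLoss_Li', loss_weight=0.5, scale_num=4),
+#         ],
+#     )
+# build_criterions(cfg) returns {'decoder_losses': [GRUSequenceLoss(...), GradientLoss_Li(...)]},
+# i.e. one instantiated loss object per entry, each updated with cfg.data_basic before construction.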
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/training/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py b/training/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py
new file mode 100644
index 0000000000000000000000000000000000000000..87aa3a23bc64494a48fc084f765ff3150eb25396
--- /dev/null
+++ b/training/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py
@@ -0,0 +1,818 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+import torch.nn.functional as F
+
+def compute_depth_expectation(prob, depth_values):
+ depth_values = depth_values.view(*depth_values.shape, 1, 1)
+ depth = torch.sum(prob * depth_values, 1)
+ return depth
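+
+# Minimal sketch of the bin-expectation idea (shapes and bin range are illustrative
+# assumptions, not the values used by the decoder below):
+#     prob = torch.softmax(torch.randn(1, 64, 8, 8), dim=1)                          # (B, D, H, W)
+#     anchors = torch.exp(torch.linspace(math.log(0.3), math.log(150.0), 64))[None]  # (B, D)
+#     depth = compute_depth_expectation(prob, anchors)                               # (B, H, W)
+# i.e. per-pixel depth is the expectation of log-spaced depth anchors under a softmax distribution.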
+
+def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
+ return F.interpolate(x.float(), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
+
+# def upflow8(flow, mode='bilinear'):
+# new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+# return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+
+def upflow4(flow, mode='bilinear'):
+ new_size = (4 * flow.shape[2], 4 * flow.shape[3])
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
+ return F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+
+def coords_grid(batch, ht, wd):
+    # Unlike the original RAFT coords_grid (a meshgrid of pixel coordinates), the state
+    # updated here has 6 channels (1 depth + 1 confidence + 4 normal), so the base grid
+    # is simply six zero-filled planes.
+    # coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+    coords = tuple(torch.zeros((ht, wd)) for _ in range(6))
+    coords = torch.stack(coords, dim=0).float()
+    return coords[None].repeat(batch, 1, 1, 1)
+
+def norm_normalize(norm_out):
+ min_kappa = 0.01
+ norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1)
+ norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10
+ kappa = F.elu(kappa) + 1.0 + min_kappa
+ final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1)
+ return final_out
+
+# uncertainty-guided sampling (only used during training)
+@torch.no_grad()
+def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta):
+ device = init_normal.device
+ B, _, H, W = init_normal.shape
+ N = int(sampling_ratio * H * W)
+ beta = beta
+
+ # uncertainty map
+ uncertainty_map = -1 * init_normal[:, -1, :, :] # B, H, W
+
+ # gt_invalid_mask (B, H, W)
+ if gt_norm_mask is not None:
+ gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
+ gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5
+ uncertainty_map[gt_invalid_mask] = -1e4
+
+ # (B, H*W)
+ _, idx = uncertainty_map.view(B, -1).sort(1, descending=True)
+
+ # importance sampling
+ if int(beta * N) > 0:
+ importance = idx[:, :int(beta * N)] # B, beta*N
+
+ # remaining
+ remaining = idx[:, int(beta * N):] # B, H*W - beta*N
+
+ # coverage
+ num_coverage = N - int(beta * N)
+
+ if num_coverage <= 0:
+ samples = importance
+ else:
+ coverage_list = []
+ for i in range(B):
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
+ samples = torch.cat((importance, coverage), dim=1) # B, N
+
+ else:
+ # remaining
+ remaining = idx[:, :] # B, H*W
+
+ # coverage
+ num_coverage = N
+
+ coverage_list = []
+ for i in range(B):
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
+ samples = coverage
+
+ # point coordinates
+ rows_int = samples // W # 0 for first row, H-1 for last row
+ rows_float = rows_int / float(H-1) # 0 to 1.0
+ rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
+
+ cols_int = samples % W # 0 for first column, W-1 for last column
+ cols_float = cols_int / float(W-1) # 0 to 1.0
+ cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
+
+ point_coords = torch.zeros(B, 1, N, 2)
+ point_coords[:, 0, :, 0] = cols_float # x coord
+ point_coords[:, 0, :, 1] = rows_float # y coord
+ point_coords = point_coords.to(device)
+ return point_coords, rows_int, cols_int
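+
+# Usage note (a sketch of the expected call pattern, not taken verbatim from this file):
+# `point_coords` is already in the [-1, 1] convention of F.grid_sample, so predictions or
+# ground truth at the sampled pixels can be gathered with e.g.
+#     sampled = F.grid_sample(pred, point_coords, mode='nearest', align_corners=True)  # (B, C, 1, N)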
+
+class FlowHead(nn.Module):
+ def __init__(self, input_dim=128, hidden_dim=256, output_dim_depth=2, output_dim_norm=4):
+ super(FlowHead, self).__init__()
+ self.conv1d = nn.Conv2d(input_dim, hidden_dim // 2, 3, padding=1)
+ self.conv2d = nn.Conv2d(hidden_dim // 2, output_dim_depth, 3, padding=1)
+
+ self.conv1n = nn.Conv2d(input_dim, hidden_dim // 2, 3, padding=1)
+ self.conv2n = nn.Conv2d(hidden_dim // 2, output_dim_norm, 3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ depth = self.conv2d(self.relu(self.conv1d(x)))
+ normal = self.conv2n(self.relu(self.conv1n(x)))
+ return torch.cat((depth, normal), dim=1)
+
+
+class ConvGRU(nn.Module):
+ def __init__(self, hidden_dim, input_dim, kernel_size=3):
+ super(ConvGRU, self).__init__()
+ self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
+ self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
+ self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2)
+
+ def forward(self, h, cz, cr, cq, *x_list):
+ x = torch.cat(x_list, dim=1)
+ hx = torch.cat([h, x], dim=1)
+
+ z = torch.sigmoid((self.convz(hx) + cz))
+ r = torch.sigmoid((self.convr(hx) + cr))
+ q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq))
+
+ # z = torch.sigmoid((self.convz(hx) + cz).float())
+ # r = torch.sigmoid((self.convr(hx) + cr).float())
+ # q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq).float())
+
+ h = (1-z) * h + z * q
+ return h
+
+def pool2x(x):
+ return F.avg_pool2d(x, 3, stride=2, padding=1)
+
+def pool4x(x):
+ return F.avg_pool2d(x, 5, stride=4, padding=1)
+
+def interp(x, dest):
+ interp_args = {'mode': 'bilinear', 'align_corners': True}
+ return interpolate_float32(x, dest.shape[2:], **interp_args)
+
+class BasicMultiUpdateBlock(nn.Module):
+ def __init__(self, args, hidden_dims=[], out_dims=2):
+ super().__init__()
+ self.args = args
+ self.n_gru_layers = args.model.decode_head.n_gru_layers # 3
+ self.n_downsample = args.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K)
+
+ # self.encoder = BasicMotionEncoder(args)
+ # encoder_output_dim = 128 # if there is corr volume
+ encoder_output_dim = 6 # no corr volume
+
+ self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1))
+ self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2])
+ self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1])
+ self.flow_head = FlowHead(hidden_dims[2], hidden_dim=2*hidden_dims[2])
+ factor = 2**self.n_downsample
+
+ self.mask = nn.Sequential(
+ nn.Conv2d(hidden_dims[2], hidden_dims[2], 3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(hidden_dims[2], (factor**2)*9, 1, padding=0))
+
+ def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True):
+
+ if iter32:
+ net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1]))
+ if iter16:
+ if self.n_gru_layers > 2:
+ net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]), interp(net[2], net[1]))
+ else:
+ net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]))
+ if iter08:
+            if corr is not None:
+                # NOTE: the motion encoder is commented out in __init__, so this branch
+                # assumes corr is always None with the current configuration.
+                motion_features = self.encoder(flow, corr)
+            else:
+                motion_features = flow
+ if self.n_gru_layers > 1:
+ net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0]))
+ else:
+ net[0] = self.gru08(net[0], *(inp[0]), motion_features)
+
+ if not update:
+ return net
+
+ delta_flow = self.flow_head(net[0])
+
+        # scale mask to balance gradients
+ mask = .25 * self.mask(net[0])
+ return net, mask, delta_flow
+
+class LayerNorm2d(nn.LayerNorm):
+ def __init__(self, dim):
+ super(LayerNorm2d, self).__init__(dim)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 3, 1).contiguous()
+ x = super(LayerNorm2d, self).forward(x)
+ x = x.permute(0, 3, 1, 2).contiguous()
+ return x
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'layer':
+ self.norm1 = LayerNorm2d(planes)
+ self.norm2 = LayerNorm2d(planes)
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = LayerNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not (stride == 1 and in_planes == planes):
+ self.norm3 = nn.Sequential()
+
+ if stride == 1 and in_planes == planes:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+ def forward(self, x):
+ y = x
+ y = self.conv1(y)
+ y = self.norm1(y)
+ y = self.relu(y)
+ y = self.conv2(y)
+ y = self.norm2(y)
+ y = self.relu(y)
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+
+class ContextFeatureEncoder(nn.Module):
+    '''
+    Encoder features are used to:
+    1. initialize the hidden state of the update operator
+    2. inject context into the GRU at each iteration of the update operator
+    '''
+    def __init__(self, in_dim, output_dim):
+        '''
+        in_dim = [x4, x8, x16, x32]
+        output_dim = [hidden_dims, context_dims]
+                   = [[x4, x8, x16, x32], [x4, x8, x16, x32]]
+        '''
+ super().__init__()
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(in_dim[0], dim[0], 'layer', stride=1),
+ nn.Conv2d(dim[0], dim[0], 3, padding=1))
+ output_list.append(conv_out)
+
+ self.outputs04 = nn.ModuleList(output_list)
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(in_dim[1], dim[1], 'layer', stride=1),
+ nn.Conv2d(dim[1], dim[1], 3, padding=1))
+ output_list.append(conv_out)
+
+ self.outputs08 = nn.ModuleList(output_list)
+
+ output_list = []
+ for dim in output_dim:
+ conv_out = nn.Sequential(
+ ResidualBlock(in_dim[2], dim[2], 'layer', stride=1),
+ nn.Conv2d(dim[2], dim[2], 3, padding=1))
+ output_list.append(conv_out)
+
+ self.outputs16 = nn.ModuleList(output_list)
+
+ # output_list = []
+ # for dim in output_dim:
+ # conv_out = nn.Conv2d(in_dim[3], dim[3], 3, padding=1)
+ # output_list.append(conv_out)
+
+ # self.outputs32 = nn.ModuleList(output_list)
+
+ def forward(self, encoder_features):
+ x_4, x_8, x_16, x_32 = encoder_features
+
+ outputs04 = [f(x_4) for f in self.outputs04]
+ outputs08 = [f(x_8) for f in self.outputs08]
+        outputs16 = [f(x_16) for f in self.outputs16]
+ # outputs32 = [f(x_32) for f in self.outputs32]
+
+ return (outputs04, outputs08, outputs16)
+
+class ConvBlock(nn.Module):
+ # reimplementation of DPT
+ def __init__(self, channels):
+ super(ConvBlock, self).__init__()
+
+ self.act = nn.ReLU(inplace=True)
+ self.conv1 = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ )
+ self.conv2 = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=1,
+ padding=1
+ )
+
+ def forward(self, x):
+ out = self.act(x)
+ out = self.conv1(out)
+ out = self.act(out)
+ out = self.conv2(out)
+ return x + out
+
+class FuseBlock(nn.Module):
+ # reimplementation of DPT
+ def __init__(self, in_channels, out_channels, fuse=True, upsample=True, scale_factor=2):
+ super(FuseBlock, self).__init__()
+
+ self.fuse = fuse
+ self.scale_factor = scale_factor
+ self.way_trunk = ConvBlock(in_channels)
+ if self.fuse:
+ self.way_branch = ConvBlock(in_channels)
+
+ self.out_conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+ self.upsample = upsample
+
+ def forward(self, x1, x2=None):
+ if x2 is not None:
+ x2 = self.way_branch(x2)
+ x1 = x1 + x2
+
+ out = self.way_trunk(x1)
+
+ if self.upsample:
+ out = interpolate_float32(
+ out, scale_factor=self.scale_factor, mode="bilinear", align_corners=True
+ )
+ out = self.out_conv(out)
+ return out
+
+class Readout(nn.Module):
+ # From DPT
+ def __init__(self, in_features, use_cls_token=True, num_register_tokens=0):
+ super(Readout, self).__init__()
+ self.use_cls_token = use_cls_token
+ if self.use_cls_token == True:
+ self.project_patch = nn.Linear(in_features, in_features)
+ self.project_learn = nn.Linear((1 + num_register_tokens) * in_features, in_features, bias=False)
+ self.act = nn.GELU()
+ else:
+ self.project = nn.Identity()
+
+ def forward(self, x):
+
+ if self.use_cls_token == True:
+ x_patch = self.project_patch(x[0])
+ x_learn = self.project_learn(x[1])
+ x_learn = x_learn.expand_as(x_patch).contiguous()
+ features = x_patch + x_learn
+ return self.act(features)
+ else:
+ return self.project(x)
+
+class Token2Feature(nn.Module):
+ # From DPT
+ def __init__(self, vit_channel, feature_channel, scale_factor, use_cls_token=True, num_register_tokens=0):
+ super(Token2Feature, self).__init__()
+ self.scale_factor = scale_factor
+ self.readoper = Readout(in_features=vit_channel, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens)
+ if scale_factor > 1 and isinstance(scale_factor, int):
+ self.sample = nn.ConvTranspose2d(
+ in_channels=vit_channel,
+ out_channels=feature_channel,
+ kernel_size=scale_factor,
+ stride=scale_factor,
+ padding=0,
+ )
+
+ elif scale_factor > 1:
+ self.sample = nn.Sequential(
+ # Upsample2(upscale=scale_factor),
+ # nn.Upsample(scale_factor=scale_factor),
+ nn.Conv2d(
+ in_channels=vit_channel,
+ out_channels=feature_channel,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+
+ elif scale_factor < 1:
+ scale_factor = int(1.0 / scale_factor)
+ self.sample = nn.Conv2d(
+ in_channels=vit_channel,
+ out_channels=feature_channel,
+ kernel_size=scale_factor+1,
+ stride=scale_factor,
+ padding=1,
+ )
+
+ else:
+ self.sample = nn.Identity()
+
+ def forward(self, x):
+ x = self.readoper(x)
+ #if use_cls_token == True:
+ x = x.permute(0, 3, 1, 2).contiguous()
+ if isinstance(self.scale_factor, float):
+ x = interpolate_float32(x.float(), scale_factor=self.scale_factor, mode='nearest')
+ x = self.sample(x)
+ return x
+
+class EncoderFeature(nn.Module):
+ def __init__(self, vit_channel, num_ch_dec=[256, 512, 1024, 1024], use_cls_token=True, num_register_tokens=0):
+ super(EncoderFeature, self).__init__()
+ self.vit_channel = vit_channel
+ self.num_ch_dec = num_ch_dec
+
+ self.read_3 = Token2Feature(self.vit_channel, self.num_ch_dec[3], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens)
+ self.read_2 = Token2Feature(self.vit_channel, self.num_ch_dec[2], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens)
+ self.read_1 = Token2Feature(self.vit_channel, self.num_ch_dec[1], scale_factor=2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens)
+ self.read_0 = Token2Feature(self.vit_channel, self.num_ch_dec[0], scale_factor=7/2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens)
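+        # ViT tokens live on a 1/14-resolution grid (patch_size 14 is assumed here), so the
+        # scale factors above map them onto the decoder pyramid: x1 keeps 1/14, x2 gives 1/7
+        # and x7/2 gives 1/4.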
+
+ def forward(self, ref_feature):
+ x = self.read_3(ref_feature[3]) # 1/14
+ x2 = self.read_2(ref_feature[2]) # 1/14
+ x1 = self.read_1(ref_feature[1]) # 1/7
+ x0 = self.read_0(ref_feature[0]) # 1/4
+
+ return x, x2, x1, x0
+
+class DecoderFeature(nn.Module):
+ def __init__(self, vit_channel, num_ch_dec=[128, 256, 512, 1024, 1024], use_cls_token=True):
+ super(DecoderFeature, self).__init__()
+ self.vit_channel = vit_channel
+ self.num_ch_dec = num_ch_dec
+
+ self.upconv_3 = FuseBlock(
+ self.num_ch_dec[4],
+ self.num_ch_dec[3],
+ fuse=False, upsample=False)
+
+ self.upconv_2 = FuseBlock(
+ self.num_ch_dec[3],
+ self.num_ch_dec[2],
+ )
+
+ self.upconv_1 = FuseBlock(
+ self.num_ch_dec[2],
+ self.num_ch_dec[1] + 2,
+ scale_factor=7/4
+ )
+
+ # self.upconv_0 = FuseBlock(
+ # self.num_ch_dec[1],
+ # self.num_ch_dec[0] + 1,
+ # )
+
+ def forward(self, ref_feature):
+ x, x2, x1, x0 = ref_feature # 1/14 1/14 1/7 1/4
+
+ x = self.upconv_3(x) # 1/14
+ x = self.upconv_2(x, x2) # 1/7
+ x = self.upconv_1(x, x1) # 1/4
+ # x = self.upconv_0(x, x0) # 4/7
+ return x
+
+class RAFTDepthNormalDPT5(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.in_channels = cfg.model.decode_head.in_channels # [1024, 1024, 1024, 1024]
+ self.feature_channels = cfg.model.decode_head.feature_channels # [256, 512, 1024, 1024] [2/7, 1/7, 1/14, 1/14]
+ self.decoder_channels = cfg.model.decode_head.decoder_channels # [128, 256, 512, 1024, 1024] [-, 1/4, 1/7, 1/14, 1/14]
+ self.use_cls_token = cfg.model.decode_head.use_cls_token
+ self.up_scale = cfg.model.decode_head.up_scale
+ self.num_register_tokens = cfg.model.decode_head.num_register_tokens
+ self.min_val = cfg.data_basic.depth_normalize[0]
+ self.max_val = cfg.data_basic.depth_normalize[1]
+ self.regress_scale = 100.0
+
+ self.hidden_dims = self.context_dims = cfg.model.decode_head.hidden_channels # [128, 128, 128, 128]
+ self.n_gru_layers = cfg.model.decode_head.n_gru_layers # 3
+ self.n_downsample = cfg.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K)
+ self.iters = cfg.model.decode_head.iters # 22
+ self.slow_fast_gru = cfg.model.decode_head.slow_fast_gru # True
+
+ self.num_depth_regressor_anchor = 256 # 512
+ self.used_res_channel = self.decoder_channels[1] # now, use 2/7 res
+ self.token2feature = EncoderFeature(self.in_channels[0], self.feature_channels, self.use_cls_token, self.num_register_tokens)
+ self.decoder_mono = DecoderFeature(self.in_channels, self.decoder_channels)
+ self.depth_regressor = nn.Sequential(
+ nn.Conv2d(self.used_res_channel,
+ self.num_depth_regressor_anchor,
+ kernel_size=3,
+ padding=1),
+ # nn.BatchNorm2d(self.num_depth_regressor_anchor),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(self.num_depth_regressor_anchor,
+ self.num_depth_regressor_anchor,
+ kernel_size=1),
+ )
+ self.normal_predictor = nn.Sequential(
+ nn.Conv2d(self.used_res_channel,
+ 128,
+ kernel_size=3,
+ padding=1),
+ # nn.BatchNorm2d(128),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=1), nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, kernel_size=1), nn.ReLU(inplace=True),
+ nn.Conv2d(128, 3, kernel_size=1),
+ )
+
+ self.context_feature_encoder = ContextFeatureEncoder(self.feature_channels, [self.hidden_dims, self.context_dims])
+ self.context_zqr_convs = nn.ModuleList([nn.Conv2d(self.context_dims[i], self.hidden_dims[i]*3, 3, padding=3//2) for i in range(self.n_gru_layers)])
+ self.update_block = BasicMultiUpdateBlock(cfg, hidden_dims=self.hidden_dims, out_dims=6)
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def get_bins(self, bins_num):
+ depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda")
+ depth_bins_vec = torch.exp(depth_bins_vec)
+ return depth_bins_vec
+
+ def register_depth_expectation_anchor(self, bins_num, B):
+ depth_bins_vec = self.get_bins(bins_num)
+ depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1)
+ self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False)
+
+ def clamp(self, x):
+ y = self.relu(x - self.min_val) + self.min_val
+ y = self.max_val - self.relu(self.max_val - y)
+ return y
+
+ def regress_depth(self, feature_map_d):
+ prob_feature = self.depth_regressor(feature_map_d)
+ prob = prob_feature.softmax(dim=1)
+ #prob = prob_feature.float().softmax(dim=1)
+
+ ## Error logging
+ if torch.isnan(prob).any():
+ print('prob_feat_nan!!!')
+ if torch.isinf(prob).any():
+ print('prob_feat_inf!!!')
+
+ # h = prob[0,:,0,0].cpu().numpy().reshape(-1)
+ # import matplotlib.pyplot as plt
+ # plt.bar(range(len(h)), h)
+ B = prob.shape[0]
+ if "depth_expectation_anchor" not in self._buffers:
+ self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B)
+ d = compute_depth_expectation(
+ prob,
+ self.depth_expectation_anchor[:B, ...]).unsqueeze(1)
+
+ ## Error logging
+        if torch.isnan(d).any():
+            print('d_nan!!!')
+        if torch.isinf(d).any():
+            print('d_inf!!!')
+
+        return (self.clamp(d) - self.max_val) / self.regress_scale, prob_feature
+
+ def pred_normal(self, feature_map, confidence):
+ normal_out = self.normal_predictor(feature_map)
+
+ ## Error logging
+ if torch.isnan(normal_out).any():
+ print('norm_nan!!!')
+ if torch.isinf(normal_out).any():
+ print('norm_feat_inf!!!')
+
+ return norm_normalize(torch.cat([normal_out, confidence], dim=1))
+ #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float())
+
+ def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
+ y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
+ torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
+ meshgrid = torch.stack((x, y))
+ meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1)
+ #self.register_buffer('meshgrid', meshgrid, persistent=False)
+ return meshgrid
+
+ def upsample_flow(self, flow, mask):
+        """ Upsample a [N, C, H/f, W/f] field to [N, C, H, W] (f = 2**n_downsample) using convex combination """
+ N, D, H, W = flow.shape
+ factor = 2 ** self.n_downsample
+ mask = mask.view(N, 1, 9, factor, factor, H, W)
+ mask = torch.softmax(mask, dim=2)
+ #mask = torch.softmax(mask.float(), dim=2)
+
+ #up_flow = F.unfold(factor * flow, [3,3], padding=1)
+ up_flow = F.unfold(flow, [3,3], padding=1)
+ up_flow = up_flow.view(N, D, 9, 1, 1, H, W)
+
+ up_flow = torch.sum(mask * up_flow, dim=2)
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+ return up_flow.reshape(N, D, factor*H, factor*W)
+
+ def initialize_flow(self, img):
+ """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
+ N, _, H, W = img.shape
+
+ coords0 = coords_grid(N, H, W).to(img.device)
+ coords1 = coords_grid(N, H, W).to(img.device)
+
+ return coords0, coords1
+
+ def upsample(self, x, scale_factor=2):
+        """Upsample the input by a factor of scale_factor * up_scale / 8 (nearest neighbour)."""
+ return interpolate_float32(x, scale_factor=scale_factor*self.up_scale/8, mode="nearest")
+
+ def forward(self, vit_features, **kwargs):
+ ## read vit token to multi-scale features
+ B, H, W, _, _, num_register_tokens = vit_features[1]
+ vit_features = vit_features[0]
+
+ ## Error logging
+ if torch.isnan(vit_features[0]).any():
+ print('vit_feature_nan!!!')
+ if torch.isinf(vit_features[0]).any():
+ print('vit_feature_inf!!!')
+
+ if self.use_cls_token == True:
+ vit_features = [[ft[:, 1+num_register_tokens:, :].view(B, H, W, self.in_channels[0]), \
+ ft[:, 0:1+num_register_tokens, :].view(B, 1, 1, self.in_channels[0] * (1+num_register_tokens))] for ft in vit_features]
+ else:
+ vit_features = [ft.view(B, H, W, self.in_channels[0]) for ft in vit_features]
+ encoder_features = self.token2feature(vit_features) # 1/14, 1/14, 1/7, 1/4
+
+ ## Error logging
+ for en_ft in encoder_features:
+ if torch.isnan(en_ft).any():
+ print('decoder_feature_nan!!!')
+ print(en_ft.shape)
+ if torch.isinf(en_ft).any():
+ print('decoder_feature_inf!!!')
+ print(en_ft.shape)
+
+ ## decode features to init-depth (and confidence)
+        ref_feat = self.decoder_mono(encoder_features)  # now 1/4 resolution for depth
+
+ ## Error logging
+ if torch.isnan(ref_feat).any():
+ print('ref_feat_nan!!!')
+ if torch.isinf(ref_feat).any():
+ print('ref_feat_inf!!!')
+
+        feature_map = ref_feat[:, :-2, :, :]  # feature map shared by the depth and normal heads
+ depth_confidence_map = ref_feat[:, -2:-1, :, :]
+ normal_confidence_map = ref_feat[:, -1:, :, :]
+ depth_pred, binmap = self.regress_depth(feature_map) # regress bin for depth
+ normal_pred = self.pred_normal(feature_map, normal_confidence_map) # mlp for normal
+
+ depth_init = torch.cat((depth_pred, depth_confidence_map, normal_pred), dim=1) # (N, 1+1+4, H, W)
+
+        ## encode features into context features for the initial hidden state and the per-iteration context
+ cnet_list = self.context_feature_encoder(encoder_features[::-1])
+ net_list = [torch.tanh(x[0]) for x in cnet_list] # x_4, x_8, x_16 of hidden state
+ inp_list = [torch.relu(x[1]) for x in cnet_list] # x_4, x_8, x_16 context features
+
+ # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning
+ inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)]
+
+ coords0, coords1 = self.initialize_flow(net_list[0])
+ if depth_init is not None:
+ coords1 = coords1 + depth_init
+
+ if self.training:
+ low_resolution_init = [self.clamp(depth_init[:,:1] * self.regress_scale + self.max_val), depth_init[:,1:2], norm_normalize(depth_init[:,2:].clone())]
+ init_depth = upflow4(depth_init)
+ flow_predictions = [self.clamp(init_depth[:,:1] * self.regress_scale + self.max_val)]
+ conf_predictions = [init_depth[:,1:2]]
+ normal_outs = [norm_normalize(init_depth[:,2:].clone())]
+
+ else:
+ flow_predictions = []
+ conf_predictions = []
+ samples_pred_list = []
+ coord_list = []
+ normal_outs = []
+ low_resolution_init = []
+
+ for itr in range(self.iters):
+ # coords1 = coords1.detach()
+ flow = coords1 - coords0
+ if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU
+ net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False)
+ if self.n_gru_layers >= 2 and self.slow_fast_gru:# Update low-res GRU and mid-res GRU
+ net_list = self.update_block(net_list, inp_list, iter32=self.n_gru_layers==3, iter16=True, iter08=False, update=False)
+ net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, None, flow, iter32=self.n_gru_layers==3, iter16=self.n_gru_layers>=2)
+
+ # F(t+1) = F(t) + \Delta(t)
+ coords1 = coords1 + delta_flow
+
+ # We do not need to upsample or output intermediate results in test_mode
+ #if (not self.training) and itr < self.iters-1:
+ #continue
+
+ # upsample predictions
+ if up_mask is None:
+ flow_up = self.upsample(coords1-coords0, 4)
+ else:
+ flow_up = self.upsample_flow(coords1 - coords0, up_mask)
+ # flow_up = self.upsample(coords1-coords0, 4)
+
+ flow_predictions.append(self.clamp(flow_up[:,:1] * self.regress_scale + self.max_val))
+ conf_predictions.append(flow_up[:,1:2])
+ normal_outs.append(norm_normalize(flow_up[:,2:].clone()))
+
+ outputs=dict(
+ prediction=flow_predictions[-1],
+ predictions_list=flow_predictions,
+ confidence=conf_predictions[-1],
+ confidence_list=conf_predictions,
+ pred_logit=None,
+ # samples_pred_list=samples_pred_list,
+ # coord_list=coord_list,
+ prediction_normal=normal_outs[-1],
+ normal_out_list=normal_outs,
+ low_resolution_init=low_resolution_init,
+ )
+
+ return outputs
+
+
+if __name__ == "__main__":
+    try:
+        from mmcv.utils import Config
+    except ImportError:
+        from mmengine import Config
+ cfg = Config.fromfile('/cpfs01/shared/public/users/mu.hu/monodepth/mono/configs/RAFTDecoder/vit.raft.full2t.py')
+ cfg.model.decode_head.in_channels = [384, 384, 384, 384]
+ cfg.model.decode_head.feature_channels = [96, 192, 384, 768]
+ cfg.model.decode_head.decoder_channels = [48, 96, 192, 384, 384]
+ cfg.model.decode_head.hidden_channels = [48, 48, 48, 48, 48]
+ cfg.model.decode_head.up_scale = 7
+
+ # cfg.model.decode_head.use_cls_token = True
+ # vit_feature = [[torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \
+ # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()]]
+
+ cfg.model.decode_head.use_cls_token = True
+ cfg.model.decode_head.num_register_tokens = 4
+ vit_feature = [[torch.rand((2, (74 * 74) + 5, 384)).cuda(),\
+ torch.rand((2, (74 * 74) + 5, 384)).cuda(), \
+ torch.rand((2, (74 * 74) + 5, 384)).cuda(), \
+ torch.rand((2, (74 * 74) + 5, 384)).cuda()], (2, 74, 74, 1036, 1036, 4)]
+
+ decoder = RAFTDepthNormalDPT5(cfg).cuda()
+ output = decoder(vit_feature)
+ temp = 1
+
+
+
+
diff --git a/training/mono/model/decode_heads/__init__.py b/training/mono/model/decode_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2453f91e124ca62437be8b9d3b7e270ceae34384
--- /dev/null
+++ b/training/mono/model/decode_heads/__init__.py
@@ -0,0 +1,4 @@
+from .RAFTDepthNormalDPTDecoder5 import RAFTDepthNormalDPT5
+
+__all__ = ['RAFTDepthNormalDPT5']
\ No newline at end of file
diff --git a/training/mono/model/losses/AdabinsLoss.py b/training/mono/model/losses/AdabinsLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dbed8db40a1d8062c3d9872581fc83339f4bbef
--- /dev/null
+++ b/training/mono/model/losses/AdabinsLoss.py
@@ -0,0 +1,101 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pad_sequence
+#from pytorch3d.loss import chamfer_distance
+
+class AdabinsLoss(nn.Module):
+ """
+ Losses employed in Adabins.
+ """
+ def __init__(self, depth_normalize, variance_focus=0.85, loss_weight=1, out_channel=100, data_type=['stereo', 'lidar'], w_ce=False, w_chamber=False, **kwargs):
+ super(AdabinsLoss, self).__init__()
+ self.variance_focus = variance_focus
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ #self.bins_num = out_channel
+ #self.cel = nn.CrossEntropyLoss(ignore_index=self.bins_num + 1)
+ self.depth_min = depth_normalize[0]
+ self.depth_max = depth_normalize[1]
+ self.w_ce = w_ce
+ self.eps = 1e-6
+
+ def silog_loss(self, prediction, target, mask):
+ d = torch.log(prediction[mask]) - torch.log(target[mask])
+ d_square_mean = torch.sum(d ** 2) / (d.numel() + self.eps)
+ d_mean = torch.sum(d) / (d.numel() + self.eps)
+ loss = torch.sqrt(d_square_mean - self.variance_focus * (d_mean ** 2))
+ return loss
+
+    def chamfer_distance_loss(self, bins, target_depth_maps, mask):
+        # NOTE: relies on `chamfer_distance` from pytorch3d (import commented out above);
+        # the call in forward() is also commented out, so this path is currently unused.
+ bin_centers = 0.5 * (bins[:, 1:] + bins[:, :-1])
+ n, p = bin_centers.shape
+ input_points = bin_centers.view(n, p, 1) # .shape = n, p, 1
+ # n, c, h, w = target_depth_maps.shape
+
+ target_points = target_depth_maps.flatten(1) # n, hwc
+ #mask = target_points.ge(1e-3) # only valid ground truth points
+ target_points = [p[m] for p, m in zip(target_depth_maps, mask)]
+        target_lengths = torch.tensor([len(t) for t in target_points], dtype=torch.long, device="cuda")
+ target_points = pad_sequence(target_points, batch_first=True).unsqueeze(2) # .shape = n, T, 1
+
+ loss, _ = chamfer_distance(x=input_points, y=target_points, y_lengths=target_lengths)
+ return loss
+
+ # def depth_to_bins(self, depth, mask, depth_edges, size_limite=(512, 960)):
+ # """
+ # Discretize depth into depth bins. Predefined bins edges are provided.
+ # Mark invalid padding area as bins_num + 1
+ # Args:
+ # @depth: 1-channel depth, [B, 1, h, w]
+ # return: depth bins [B, C, h, w]
+ # """
+ # def _depth_to_bins_block_(depth, mask, depth_edges):
+ # bins_id = torch.sum(depth_edges[:, None, None, None, :] < torch.abs(depth)[:, :, :, :, None], dim=-1)
+ # bins_id = bins_id - 1
+ # invalid_mask = ~mask
+ # mask_lower = (depth <= self.depth_min)
+ # mask_higher = (depth >= self.depth_max)
+
+ # bins_id[mask_lower] = 0
+ # bins_id[mask_higher] = self.bins_num - 1
+ # bins_id[bins_id == self.bins_num] = self.bins_num - 1
+
+ # bins_id[invalid_mask] = self.bins_num + 1
+ # return bins_id
+ # # _, _, H, W = depth.shape
+ # # bins = mask.clone().long()
+ # # h_blocks = np.ceil(H / size_limite[0]).astype(np.int)
+ # # w_blocks = np.ceil(W/ size_limite[1]).astype(np.int)
+ # # for i in range(h_blocks):
+ # # for j in range(w_blocks):
+ # # h_start = i*size_limite[0]
+ # # h_end_proposal = (i + 1) * size_limite[0]
+ # # h_end = h_end_proposal if h_end_proposal < H else H
+ # # w_start = j*size_limite[1]
+ # # w_end_proposal = (j + 1) * size_limite[1]
+ # # w_end = w_end_proposal if w_end_proposal < W else W
+ # # bins_ij = _depth_to_bins_block_(
+ # # depth[:, :, h_start:h_end, w_start:w_end],
+ # # mask[:, :, h_start:h_end, w_start:w_end],
+ # # depth_edges
+ # # )
+ # # bins[:, :, h_start:h_end, w_start:w_end] = bins_ij
+ # bins = _depth_to_bins_block_(depth, mask, depth_edges)
+ # return bins
+
+ # def ce_loss(self, pred_logit, target, mask, bins_edges):
+ # target_depth_bins = self.depth_to_bins(target, mask, bins_edges)
+ # loss = self.cel(pred_logit, target_depth_bins.squeeze().long())
+ # return loss
+
+
+ def forward(self, prediction, target, bins_edges, mask=None, **kwargs):
+ silog_loss = self.silog_loss(prediction=prediction, target=target, mask=mask)
+ #cf_loss = self.chamfer_distance_loss(bins=bins_edges, target_depth_maps=target, mask=mask)
+ loss = silog_loss * 10 #+ 0.1 * cf_loss
+ # if self.w_ce:
+ # loss = loss + self.ce_loss(kwargs['pred_logit'], target, mask, bins_edges)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ raise RuntimeError(f'Adabins loss error, {loss}')
+ return loss * self.loss_weight
\ No newline at end of file
diff --git a/training/mono/model/losses/ConfidenceGuideLoss.py b/training/mono/model/losses/ConfidenceGuideLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c1d9cc4829e44423850826a3ef5bccfc7a49835
--- /dev/null
+++ b/training/mono/model/losses/ConfidenceGuideLoss.py
@@ -0,0 +1,54 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class ConfidenceGuideLoss(nn.Module):
+ """
+ confidence guide depth loss.
+ """
+ def __init__(self, loss_weight=1, data_type=['stereo', 'lidar', 'denselidar'], loss_gamma=0.9, conf_loss=True, **kwargs):
+ super(ConfidenceGuideLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+ self.loss_gamma = loss_gamma
+ self.conf_loss = conf_loss
+
+ def forward(self, samples_pred_list, target, coord_list, mask=None, **kwargs):
+ loss = 0.0
+ n_predictions = len(samples_pred_list)
+ for i, (pred, coord) in enumerate(zip(samples_pred_list, coord_list)):
+ # coord: B, 1, N, 2
+ # pred: B, 2, N
+ gt_depth_ = F.grid_sample(target, coord, mode='nearest', align_corners=True) # (B, 1, 1, N)
+ gt_depth_mask_ = F.grid_sample(mask.float(), coord, mode='nearest', align_corners=True) # (B, 1, 1, N)
+ gt_depth_ = gt_depth_[:, :, 0, :]
+ gt_depth_mask_ = gt_depth_mask_[:, :, 0, :] > 0.5
+
+ pred_depth, pred_conf = pred[:, :1, :], pred[:, 1:, :]
+
+ # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo iterations
+ adjusted_loss_gamma = self.loss_gamma**(15/(n_predictions - 1))
+ i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
+
+ # depth L1 loss
+ diff = torch.abs(pred_depth - gt_depth_) * gt_depth_mask_
+ curr_loss = torch.sum(diff) / (torch.sum(gt_depth_mask_) + self.eps)
+            if torch.isnan(curr_loss).item() | torch.isinf(curr_loss).item():
+                print(f'ConfidenceGuideLoss-depth NAN error, {curr_loss}')
+                curr_loss = 0 * torch.sum(pred_depth)
+
+ # confidence L1 loss
+ conf_loss = 0.0
+ if self.conf_loss:
+ conf_mask = torch.abs(gt_depth_ - pred_depth) < gt_depth_
+ conf_mask = conf_mask & gt_depth_mask_
+ gt_confidence = (1 - torch.abs((pred_depth - gt_depth_) / gt_depth_)) * conf_mask
+ conf_loss = torch.sum(torch.abs(pred_conf - gt_confidence) * conf_mask) / (torch.sum(conf_mask) + self.eps)
+                if torch.isnan(conf_loss).item() | torch.isinf(conf_loss).item():
+                    print(f'ConfidenceGuideLoss-confidence NAN error, {conf_loss}')
+                    conf_loss = 0 * torch.sum(pred_conf)
+
+ loss += (conf_loss + curr_loss) * i_weight
+
+ return loss * self.loss_weight
\ No newline at end of file
diff --git a/training/mono/model/losses/ConfidenceLoss.py b/training/mono/model/losses/ConfidenceLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8ed6d7d7eadf0e1c1be009f88335a04a04e3d2d
--- /dev/null
+++ b/training/mono/model/losses/ConfidenceLoss.py
@@ -0,0 +1,22 @@
+import torch
+import torch.nn as nn
+
+class ConfidenceLoss(nn.Module):
+ """
+ confidence loss.
+ """
+ def __init__(self, loss_weight=1, data_type=['stereo', 'lidar', 'denselidar'], **kwargs):
+ super(ConfidenceLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+ def forward(self, prediction, target, confidence, mask=None, **kwargs):
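+        # The regression target for the confidence head is 1 - relative depth error,
+        # restricted to pixels where that error is below 1 (keeping the target in [0, 1])
+        # and where the ground truth is valid.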
+ conf_mask = torch.abs(target - prediction) < target
+ conf_mask = conf_mask & mask
+ gt_confidence = (1 - torch.abs((prediction - target) / target)) * conf_mask
+ loss = torch.sum(torch.abs(confidence - gt_confidence) * conf_mask) / (torch.sum(conf_mask) + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = 0 * torch.sum(confidence)
+ print(f'ConfidenceLoss NAN error, {loss}')
+ return loss * self.loss_weight
\ No newline at end of file
diff --git a/training/mono/model/losses/GRUSequenceLoss.py b/training/mono/model/losses/GRUSequenceLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d829f2874dc007260be349955b7bea62debd8ae
--- /dev/null
+++ b/training/mono/model/losses/GRUSequenceLoss.py
@@ -0,0 +1,181 @@
+import torch
+import torch.nn as nn
+
+class GRUSequenceLoss(nn.Module):
+ """
+ Loss function defined over sequence of depth predictions
+ """
+ def __init__(self, loss_weight=1, data_type=['lidar', 'denselidar', 'stereo', 'denselidar_syn'], loss_gamma=0.9, silog=False, stereo_sup=0.001, stereo_dataset=['KITTI', 'NYU'], **kwargs):
+ super(GRUSequenceLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+ self.loss_gamma = loss_gamma
+ self.silog = silog
+ self.variance_focus = 0.5
+ self.stereo_sup = stereo_sup
+ self.stereo_dataset = stereo_dataset
+
+ # assert stereo_mode in ['stereo', 'self_sup']
+ # self.stereo_mode = stereo_mode
+ # self.stereo_max = stereo_max
+
+ def silog_loss(self, prediction, target, mask):
+ mask = mask & (prediction > 0.01) & (target > 0.01)
+ d = torch.log(prediction[mask]) - torch.log(target[mask])
+ # d_square_mean = torch.sum(d ** 2) / (d.numel() + self.eps)
+ # d_mean = torch.sum(d) / (d.numel() + self.eps)
+ # loss = d_square_mean - self.variance_focus * (d_mean ** 2)
+ loss = torch.sum(torch.abs(d)) / (d.numel() + self.eps)
+ print("new log l1 loss")
+ return loss
+
+ def conf_loss(self, confidence, prediction, target, mask):
+ conf_mask = torch.abs(target - prediction) < target
+ conf_mask = conf_mask & mask
+ gt_confidence = (1 - torch.abs((prediction - target) / target)) * conf_mask
+ loss = torch.sum(torch.abs(confidence - gt_confidence) * conf_mask) / (torch.sum(conf_mask) + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ print(f'GRUSequenceLoss-confidence NAN error, {loss}')
+ loss = 0 * torch.sum(confidence)
+ return loss
+
+ def forward(self, predictions_list, target, stereo_depth, confidence_list=None, mask=None, **kwargs):
+ device = target.device
+
+ batches_dataset = kwargs['dataset']
+ self.batch_with_stereo = torch.tensor([1 if batch_dataset in self.stereo_dataset else 0 \
+ for batch_dataset in batches_dataset], device=device)[:,None,None,None]
+
+ n_predictions = len(predictions_list)
+ assert n_predictions >= 1
+ loss = 0.0
+
+ for i, prediction in enumerate(predictions_list):
+ # if self.stereo_mode == 'self_sup' and self.stereo_sup > 1e-8:
+ # B, C, H, W = target.shape
+ # prediction_nan = prediction.clone().detach()
+ # target_nan = target.clone()
+ # prediction_nan[~mask] = float('nan')
+ # target_nan[~mask] = float('nan')
+ # gt_median = target_nan.reshape((B, C,-1)).nanmedian(2)[0][:, :, None, None]
+
+ # pred_median = prediction_nan.reshape((B, C,-1)).nanmedian(2)[0][:, :, None, None]
+ # scale = gt_median / (pred_median + 1e-8)
+
+ # stereo_depth = (0.0 * stereo_depth + scale * prediction * (prediction < (self.stereo_max - 1)) + \
+ # prediction * (prediction > (self.stereo_max - 1))).detach()
+
+ # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo iterations
+ adjusted_loss_gamma = self.loss_gamma**(15/(n_predictions - 1))
+ i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
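+            # With this adjustment the earliest prediction is always weighted by
+            # loss_gamma**15 relative to the last one (~0.21 for loss_gamma=0.9),
+            # regardless of how many GRU iterations were run.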
+
+ # depth L1 loss
+ if self.silog and mask.sum() > 0:
+ curr_loss = self.silog_loss(prediction, target, mask)
+ else:
+ diff = torch.abs(prediction - target) * mask
+ #diff = diff + diff * diff * 1.0
+ curr_loss = torch.sum(diff) / (torch.sum(mask) + self.eps)
+ if torch.isnan(curr_loss).item() | torch.isinf(curr_loss).item():
+ print(f'GRUSequenceLoss-depth NAN error, {curr_loss}')
+ curr_loss = 0 * torch.sum(prediction)
+
+ # confidence L1 loss
+ conf_loss = 0
+ if confidence_list is not None:
+ conf_loss = self.conf_loss(confidence_list[i], prediction, target, mask)
+
+ # stereo depth loss
+ mask_stereo = 1 + torch.nn.functional.max_pool2d(\
+ - torch.nn.functional.max_pool2d(mask * 1.0, 3, stride=1, padding=1, dilation=1), 3, stride=1, padding=1, dilation=1)
+
+ stereo_diff = torch.abs(prediction - stereo_depth) * mask_stereo
+ #stereo_diff = stereo_diff + stereo_diff * stereo_diff * 1.0
+ stereo_depth_loss = torch.sum(self.batch_with_stereo * stereo_diff * mask_stereo) / (torch.sum(mask_stereo) + self.eps)
+ stereo_depth_loss = self.stereo_sup * stereo_depth_loss
+
+ loss += (conf_loss + curr_loss + stereo_depth_loss) * i_weight
+ #raise RuntimeError(f'Silog error, {loss}, d_square_mean: {d_square_mean}, d_mean: {d_mean}')
+ return loss * self.loss_weight
+
+# import torch
+# import torch.nn as nn
+
+# class GRUSequenceLoss(nn.Module):
+# """
+# Loss function defined over sequence of depth predictions
+# """
+# def __init__(self, loss_weight=1, data_type=['lidar', 'denselidar', 'stereo', 'denselidar_syn'], loss_gamma=0.9, silog=False, stereo_sup=0.001, stereo_dataset=['BigData'], **kwargs):
+# super(GRUSequenceLoss, self).__init__()
+# self.loss_weight = loss_weight
+# self.data_type = data_type
+# self.eps = 1e-6
+# self.loss_gamma = loss_gamma
+# self.silog = silog
+# self.variance_focus = 0.5
+# self.stereo_sup = stereo_sup
+# self.stereo_dataset = stereo_dataset
+
+# def silog_loss(self, prediction, target, mask):
+# mask = mask & (prediction > 0.01) & (target > 0.01)
+# d = torch.log(prediction[mask]) - torch.log(target[mask])
+# # d_square_mean = torch.sum(d ** 2) / (d.numel() + self.eps)
+# # d_mean = torch.sum(d) / (d.numel() + self.eps)
+# # loss = d_square_mean - self.variance_focus * (d_mean ** 2)
+# loss = torch.sum(torch.abs(d)) / (d.numel() + self.eps)
+# print("new log l1 loss")
+# return loss
+
+# def conf_loss(self, confidence, prediction, target, mask):
+# conf_mask = torch.abs(target - prediction) < target
+# conf_mask = conf_mask & mask
+# gt_confidence = (1 - torch.abs((prediction - target) / target)) * conf_mask
+# loss = torch.sum(torch.abs(confidence - gt_confidence) * conf_mask) / (torch.sum(conf_mask) + self.eps)
+# if torch.isnan(loss).item() | torch.isinf(loss).item():
+# print(f'GRUSequenceLoss-confidence NAN error, {loss}')
+# loss = 0 * torch.sum(confidence)
+# return loss
+
+# def forward(self, predictions_list, target, stereo_depth, confidence_list=None, mask=None, **kwargs):
+# device = target.device
+
+# batches_dataset = kwargs['dataset']
+# self.batch_with_stereo = torch.tensor([1 if batch_dataset in self.stereo_dataset else 0 \
+# for batch_dataset in batches_dataset], device=device)[:,None,None,None]
+
+# n_predictions = len(predictions_list)
+# assert n_predictions >= 1
+# loss = 0.0
+
+# for i, prediction in enumerate(predictions_list):
+# # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo iterations
+# adjusted_loss_gamma = self.loss_gamma**(15/(n_predictions - 1))
+# i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
+
+# # depth L1 loss
+# if self.silog and mask.sum() > 0:
+# curr_loss = self.silog_loss(prediction, target, mask)
+# else:
+# diff = torch.abs(prediction - target) * mask
+# curr_loss = torch.sum(diff) / (torch.sum(mask) + self.eps)
+# if torch.isnan(curr_loss).item() | torch.isinf(curr_loss).item():
+# print(f'GRUSequenceLoss-depth NAN error, {curr_loss}')
+# curr_loss = 0 * torch.sum(prediction)
+
+# # confidence L1 loss
+# conf_loss = 0
+# if confidence_list is not None:
+# conf_loss = self.conf_loss(confidence_list[i], prediction, target, mask)
+
+# # stereo depth loss
+# mask_stereo = 1 + torch.nn.functional.max_pool2d(\
+# - torch.nn.functional.max_pool2d(mask * 1.0, 5, stride=1, padding=2, dilation=1), 5, stride=1, padding=2, dilation=1)
+
+# stereo_diff = torch.abs(prediction - stereo_depth) * mask_stereo
+# stereo_depth_loss = torch.sum(self.batch_with_stereo * stereo_diff * mask_stereo) / (torch.sum(mask_stereo) + self.eps)
+# stereo_depth_loss = self.stereo_sup * stereo_depth_loss
+
+# loss += (conf_loss + curr_loss + stereo_depth_loss) * i_weight
+# #raise RuntimeError(f'Silog error, {loss}, d_square_mean: {d_square_mean}, d_mean: {d_mean}')
+# return loss * self.loss_weight
\ No newline at end of file
diff --git a/training/mono/model/losses/Gradient.py b/training/mono/model/losses/Gradient.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b730917acc4dde1b74000b40e2a2aceb81d2aed
--- /dev/null
+++ b/training/mono/model/losses/Gradient.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+
+EPSILON = 1e-6
+"""
+ # @Zhengqi Li version.
+ def GradientLoss(self, log_prediction_d, mask, log_gt):
+ log_d_diff = log_prediction_d - log_gt
+
+ v_gradient = torch.abs(log_d_diff[:, :-2, :] - log_d_diff[:, 2:, :])
+ v_mask = torch.mul(mask[:, :-2, :], mask[:, 2:, :])
+ v_gradient = torch.mul(v_gradient, v_mask)
+
+ h_gradient = torch.abs(log_d_diff[:, :, :-2] - log_d_diff[:, :, 2:])
+ h_mask = torch.mul(mask[:, :, :-2], mask[:, :, 2:])
+ h_gradient = torch.mul(h_gradient, h_mask)
+
+ N = torch.sum(h_mask) + torch.sum(v_mask) + EPSILON
+
+ gradient_loss = torch.sum(h_gradient) + torch.sum(v_gradient)
+ gradient_loss = gradient_loss / N
+
+ return gradient_loss
+"""
+def gradient_log_loss(log_prediction_d, log_gt, mask):
+ log_d_diff = log_prediction_d - log_gt
+
+ v_gradient = torch.abs(log_d_diff[:, :, :-2, :] - log_d_diff[:, :, 2:, :])
+ v_mask = torch.mul(mask[:, :, :-2, :], mask[:, :, 2:, :])
+ v_gradient = torch.mul(v_gradient, v_mask)
+
+ h_gradient = torch.abs(log_d_diff[:, :, :, :-2] - log_d_diff[:, :, :, 2:])
+ h_mask = torch.mul(mask[:, :, :, :-2], mask[:, :, :, 2:])
+ h_gradient = torch.mul(h_gradient, h_mask)
+
+ N = torch.sum(h_mask) + torch.sum(v_mask) + EPSILON
+
+ gradient_loss = torch.sum(h_gradient) + torch.sum(v_gradient)
+ gradient_loss = gradient_loss / N
+
+ return gradient_loss
+
+class GradientLoss_Li(nn.Module):
+ def __init__(self, scale_num=1, loss_weight=1, data_type = ['lidar', 'stereo'], **kwargs):
+ super(GradientLoss_Li, self).__init__()
+ self.__scales = scale_num
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+ def forward(self, prediction, target, mask, **kwargs):
+ total = 0
+ target_trans = target + (~mask) * 100
+ pred_log = torch.log(prediction)
+ gt_log = torch.log(target_trans)
+ for scale in range(self.__scales):
+ step = pow(2, scale)
+
+            # prediction/target are (B, 1, H, W), so subsample both spatial dimensions at each scale
+            total += gradient_log_loss(pred_log[:, :, ::step, ::step], gt_log[:, :, ::step, ::step], mask[:, :, ::step, ::step])
+ loss = total / self.__scales
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+            raise RuntimeError(f'Gradient loss error, {loss}')
+ return loss * self.loss_weight
+
+######################################################
+# Multi-scale gradient matching loss, @Ke Xian implementation.
+#####################################################
+def gradient_loss(prediction, target, mask):
+ M = torch.sum(mask, (1, 2))
+
+ diff = prediction - target
+ diff = torch.mul(mask, diff)
+
+ grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
+ mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
+ grad_x = torch.mul(mask_x, grad_x)
+
+ grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
+ mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
+ grad_y = torch.mul(mask_y, grad_y)
+
+ image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
+ valid = M.nonzero()
+ if image_loss[valid].numel() > 0:
+ image_loss[valid] = image_loss[valid] / M[valid]
+ loss = torch.mean(image_loss)
+ else:
+ loss = 0 * torch.sum(prediction)
+
+ return loss
+
+
+class GradientLoss(nn.Module):
+ def __init__(self, scale_num=4, loss_weight=1, **kwargs):
+ super(GradientLoss, self).__init__()
+ self.__scales = scale_num
+ self.loss_weight = loss_weight
+ def forward(self, prediction, target, mask, **kwargs):
+ total = 0
+ for scale in range(self.__scales):
+ step = pow(2, scale)
+ total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step], mask[:, ::step, ::step])
+
+ return total * self.loss_weight
+
+
+if __name__ == '__main__':
+ import numpy as np
+ gradient = GradientLoss_Li(4)
+
+ pred_depth = np.random.random([2, 1, 480, 640])
+ gt_depth = np.ones_like(pred_depth) * (-1) #np.random.random([2, 1, 480, 640]) - 0.5 #
+ #gt_depth = np.abs(gt_depth)
+ intrinsic = [[100, 100, 200, 200], [200, 200, 300, 300]]
+
+ pred = torch.from_numpy(pred_depth).cuda()
+ gt = torch.from_numpy(gt_depth).cuda()
+ mask = gt > 0
+
+ loss = gradient(gt, gt, mask)
+ print(loss)
\ No newline at end of file
diff --git a/training/mono/model/losses/HDNL.py b/training/mono/model/losses/HDNL.py
new file mode 100644
index 0000000000000000000000000000000000000000..db2e95caf1f87e836581d41517e1db21935eda08
--- /dev/null
+++ b/training/mono/model/losses/HDNL.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+
+class HDNLoss(nn.Module):
+    """
+    Hierarchical depth normalization loss.
+    loss = MAE( (d - median(d)) / s - (d' - median(d')) / s' ),  with s = mean(|d - median(d)|)
+    """
+ def __init__(self, loss_weight=1, grid=3, data_type=['sfm', 'stereo', 'lidar'], **kwargs):
+ super(HDNLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.grid = grid
+ self.data_type = data_type
+
+ def get_hierachy_masks(self, grid, depth_gt, mask_valid):
+
+ batch_map_grid = []
+ for mask_index in range(depth_gt.shape[0]):
+ depth_map = depth_gt[mask_index]
+ valid_map = mask_valid[mask_index]
+
+ # print (depth_map[valid_map].view(-1).shape)
+ if depth_map[valid_map].numel() == 0:
+ map_grid_list = [valid_map for _ in range(2 ** (grid) - 1)]
+ else:
+ valid_values = depth_map[valid_map]
+
+ max_d = valid_values.max()
+ min_d = valid_values.min()
+
+ anchor_power = [(1 / 2) ** (i) for i in range(grid)]
+ anchor_power.reverse()
+
+ map_grid_list = []
+ for anchor in anchor_power:
+ # range
+ for i in range(int(1 / anchor)):
+ mask_new = (depth_map >= min_d + (max_d - min_d) * i * anchor) & (
+ depth_map < min_d + (max_d - min_d) * (i + 1) * anchor+1e-30)
+ # print (f'[{i*anchor},{(i+1)*anchor}]')
+ mask_new = mask_new & valid_map
+ map_grid_list.append(mask_new)
+ map_grid_list = torch.stack(map_grid_list, dim=0)
+ batch_map_grid.append(map_grid_list)
+ batch_map_grid = torch.stack(batch_map_grid, dim=1)
+ return batch_map_grid
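+    # For grid=3 get_hierachy_masks returns 1 + 2 + 4 = 7 (= 2**grid - 1) masks per image:
+    # the full valid depth range, its two halves and its four quarters, each of which is
+    # normalized independently in ssi_mae below.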
+
+ def ssi_mae(self, prediction, target, mask_valid):
+ B, C, H, W = target.shape
+ prediction_nan = prediction.clone()
+ target_nan = target.clone()
+ prediction_nan[~mask_valid] = float('nan')
+ target_nan[~mask_valid] = float('nan')
+
+ valid_pixs = mask_valid.reshape((B, C,-1)).sum(dim=2, keepdims=True) + 1e-10
+ valid_pixs = valid_pixs[:, :, :, None]
+
+ gt_median = target_nan.reshape((B, C,-1)).nanmedian(2, keepdims=True)[0].unsqueeze(-1) # [b,c,h,w]
+ gt_median[torch.isnan(gt_median)] = 0
+ gt_diff = (torch.abs(target - gt_median) * mask_valid).reshape((B, C, -1))
+ gt_s = gt_diff.sum(dim=2)[:, :, None, None] / valid_pixs
+ gt_trans = (target - gt_median) / (gt_s + 1e-8)
+
+ pred_median = prediction_nan.reshape((B, C,-1)).nanmedian(2, keepdims=True)[0].unsqueeze(-1) # [b,c,h,w]
+ pred_median[torch.isnan(pred_median)] = 0
+ pred_diff = (torch.abs(prediction - pred_median) * mask_valid).reshape((B, C, -1))
+ pred_s = pred_diff.sum(dim=2)[:, :, None, None] / valid_pixs
+ pred_trans = (prediction - pred_median) / (pred_s + 1e-8)
+
+ loss = torch.sum(torch.abs(gt_trans - pred_trans)*mask_valid) / (torch.sum(mask_valid) + 1e-8)
+ return pred_trans, gt_trans, loss
+
+ def forward(self, prediction, target, mask=None, **kwargs):
+ """
+ Calculate loss.
+ """
+ B, C, H, W = target.shape
+ hierachy_masks = self.get_hierachy_masks(self.grid, target, mask)
+ hierachy_masks_shape = hierachy_masks.reshape(-1, C, H, W)
+ prediction_hie = prediction.unsqueeze(0).repeat(hierachy_masks.shape[0], 1, 1, 1, 1).reshape(-1, C, H, W)
+
+ target_hie = target.unsqueeze(0).repeat(hierachy_masks.shape[0], 1, 1, 1, 1).reshape(-1, C, H, W)
+
+ #_, _, loss = self.ssi_mae(prediction, target, mask)
+ _, _, loss = self.ssi_mae(prediction_hie, target_hie, hierachy_masks_shape)
+ return loss * self.loss_weight
+
+if __name__ == '__main__':
+ ssil = HDNLoss()
+ pred = torch.rand((2, 1, 256, 256)).cuda()
+ gt = torch.rand((2, 1, 256, 256)).cuda()#torch.zeros_like(pred).cuda() #
+ gt[:, :, 100:256, 0:100] = -1
+ mask = gt > 0
+ out = ssil(pred, gt, mask)
+ print(out)
diff --git a/training/mono/model/losses/HDNL_random.py b/training/mono/model/losses/HDNL_random.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0b40eb0d0652ceb012ae89bf78db8f2d763720a
--- /dev/null
+++ b/training/mono/model/losses/HDNL_random.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+class HDNRandomLoss(nn.Module):
+ """
+    Hierarchical depth normalization loss. Replaces the original hierarchical depth ranges with randomly sampled ranges.
+    loss = MAE((d - median(d)) / s - (d' - median(d')) / s'), where s = mean(|d - median(d)|)
+ """
+ def __init__(self, loss_weight=1, random_num=32, data_type=['sfm', 'stereo', 'lidar', 'denselidar', 'denselidar_nometric', 'denselidar_syn'], norm_dataset=['Taskonomy', 'Matterport3D', 'Replica', 'Hypersim'], disable_dataset=['MapillaryPSD'], **kwargs):
+ super(HDNRandomLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.random_num = random_num
+ self.eps = 1e-6
+ self.data_type = data_type
+ self.disable_dataset = disable_dataset
+
+ def get_random_masks_for_batch(self, depth_gt: torch.Tensor, mask_valid: torch.Tensor)-> torch.Tensor:
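+        # Draws `random_num` random depth sub-ranges inside the valid GT span [min_d, max_d];
+        # each sub-range yields one boolean mask selecting the valid pixels whose GT depth lies in it.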
+ valid_values = depth_gt[mask_valid]
+ max_d = valid_values.max().item() if valid_values.numel() > 0 else 0.0
+ min_d = valid_values.min().item() if valid_values.numel() > 0 else 0.0
+
+ sample_min_d = np.random.uniform(0, 0.75, self.random_num) * (max_d - min_d) + min_d
+ sample_max_d = np.random.uniform(sample_min_d + 0.1, 1-self.eps, self.random_num) * (max_d - min_d) + min_d
+
+ mask_new = [(depth_gt >= sample_min_d[i]) & (depth_gt < sample_max_d[i] + 1e-30) & mask_valid for i in range(self.random_num)]
+ mask_new = torch.stack(mask_new, dim=0).cuda() #[N, 1, H, W]
+ return mask_new
+
+ def ssi_mae(self, prediction, target, mask_valid):
+ B, C, H, W = target.shape
+ prediction_nan = prediction.clone().detach()
+ target_nan = target.clone()
+ prediction_nan[~mask_valid] = float('nan')
+ target_nan[~mask_valid] = float('nan')
+
+ valid_pixs = mask_valid.reshape((B, C,-1)).sum(dim=2, keepdims=True) + self.eps
+ valid_pixs = valid_pixs[:, :, :, None]
+
+ gt_median = target_nan.reshape((B, C,-1)).nanmedian(2, keepdims=True)[0].unsqueeze(-1) # [b,c,h,w]
+ gt_median[torch.isnan(gt_median)] = 0
+ gt_diff = (torch.abs(target - gt_median) * mask_valid).reshape((B, C, -1))
+ gt_s = gt_diff.sum(dim=2)[:, :, None, None] / valid_pixs
+ gt_trans = (target - gt_median) / (gt_s + self.eps)
+
+ pred_median = prediction_nan.reshape((B, C,-1)).nanmedian(2, keepdims=True)[0].unsqueeze(-1) # [b,c,h,w]
+ pred_median[torch.isnan(pred_median)] = 0
+ pred_diff = (torch.abs(prediction - pred_median) * mask_valid).reshape((B, C, -1))
+ pred_s = pred_diff.sum(dim=2)[:, :, None, None] / valid_pixs
+ pred_trans = (prediction - pred_median) / (pred_s + self.eps)
+
+ loss_sum = torch.sum(torch.abs(gt_trans - pred_trans)*mask_valid)
+ return loss_sum
+
+ def forward(self, prediction, target, mask=None, **kwargs):
+ """
+ Calculate loss.
+ """
+ B, C, H, W = target.shape
+
+ loss = 0.0
+ valid_pix = 0.0
+
+ device = target.device
+
+ batches_dataset = kwargs['dataset']
+ self.batch_valid = torch.tensor([1 if batch_dataset not in self.disable_dataset else 0 \
+ for batch_dataset in batches_dataset], device=device)[:,None,None,None]
+
+ batch_limit = 4
+ loops = int(np.ceil(self.random_num / batch_limit))
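+        # Process the random masks in chunks of `batch_limit` to bound peak memory; the per-chunk
+        # SSI-MAE sums and valid-pixel counts are accumulated and normalized once at the end.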
+ for i in range(B):
+ mask_i = mask[i, ...] #[1, H, W]
+
+ if self.batch_valid[i, ...] < 0.5:
+ loss += 0 * torch.sum(prediction[i, ...])
+ valid_pix += 0 * torch.sum(mask_i)
+ continue
+
+ pred_i = prediction[i, ...].unsqueeze(0).repeat(batch_limit, 1, 1, 1)
+ target_i = target[i, ...].unsqueeze(0).repeat(batch_limit, 1, 1, 1)
+ mask_random_drange = self.get_random_masks_for_batch(target[i, ...], mask_i) # [N, 1, H, W]
+ for j in range(loops):
+ mask_random_loopi = mask_random_drange[j*batch_limit:(j+1)*batch_limit, ...]
+ loss += self.ssi_mae(
+ prediction=pred_i[:mask_random_loopi.shape[0], ...],
+ target=target_i[:mask_random_loopi.shape[0], ...],
+ mask_valid=mask_random_loopi)
+ valid_pix += torch.sum(mask_random_loopi)
+
+ loss = loss / (valid_pix + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = 0 * torch.sum(prediction)
+ print(f'HDNL NAN error, {loss}, valid pix: {valid_pix}')
+ return loss * self.loss_weight
+
+if __name__ == '__main__':
+ ssil = HDNRandomLoss()
+ pred = torch.rand((2, 1, 256, 256)).cuda()
+ gt = - torch.rand((2, 1, 256, 256)).cuda()#torch.zeros_like(pred).cuda() #
+ gt[:, :, 100:256, 0:100] = -1
+ mask = gt > 0
+ out = ssil(pred, gt, mask)
+ print(out)
diff --git a/training/mono/model/losses/HDSNL.py b/training/mono/model/losses/HDSNL.py
new file mode 100644
index 0000000000000000000000000000000000000000..250671b5ad52faf8f3d1e5bac41ad898ca3967a2
--- /dev/null
+++ b/training/mono/model/losses/HDSNL.py
@@ -0,0 +1,82 @@
+import torch
+import torch.nn as nn
+
+class HDSNLoss(nn.Module):
+ """
+    Hierarchical depth spatial normalization loss.
+    loss = MAE((d - median(d)) / s - (d' - median(d')) / s'), where s = mean(|d - median(d)|)
+ """
+ def __init__(self, loss_weight=1.0, grid=3, data_type=['sfm', 'stereo', 'lidar'], **kwargs):
+ super(HDSNLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.grid = grid
+ self.data_type = data_type
+
+ def get_hierachy_masks(self, batch, image_size, mask):
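+        # Spatial hierarchy: the image is tiled into 1x1, 2x2, ..., 2**(grid-1) x 2**(grid-1) windows,
+        # giving sum_k 4**k masks (grid=3 -> 1 + 4 + 16 = 21); every tile mask is intersected with
+        # the valid-pixel mask before the scale-shift-invariant loss is applied per tile.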
+ height, width = image_size
+ anchor_power = [(1 / 2) ** (i) for i in range(self.grid)]
+ anchor_power.reverse()
+
+ map_grid_list = []
+ for anchor in anchor_power: # e.g. 1/8
+ for h in range(int(1 / anchor)):
+ for w in range(int(1 / anchor)):
+ mask_new = torch.zeros((batch, 1, height, width), dtype=torch.bool).cuda()
+ mask_new[:, :, int(h * anchor * height):int((h + 1) * anchor * height),
+ int(w * anchor * width):int((w + 1) * anchor * width)] = True
+ mask_new = mask & mask_new
+ map_grid_list.append(mask_new)
+ batch_map_grid=torch.stack(map_grid_list,dim=0) # [N, B, 1, H, W]
+
+ return batch_map_grid
+
+ def ssi_mae(self, prediction, target, mask_valid):
+ B, C, H, W = target.shape
+ prediction_nan = prediction.clone()
+ target_nan = target.clone()
+ prediction_nan[~mask_valid] = float('nan')
+ target_nan[~mask_valid] = float('nan')
+
+ valid_pixs = mask_valid.reshape((B, C,-1)).sum(dim=2, keepdims=True) + 1e-10
+ valid_pixs = valid_pixs[:, :, :, None]
+
+ gt_median = target_nan.reshape((B, C,-1)).nanmedian(2, keepdims=True)[0].unsqueeze(-1) # [b,c,h,w]
+ gt_median[torch.isnan(gt_median)] = 0
+ gt_diff = (torch.abs(target - gt_median) * mask_valid).reshape((B, C, -1))
+ gt_s = gt_diff.sum(dim=2)[:, :, None, None] / valid_pixs
+ gt_trans = (target - gt_median) / (gt_s + 1e-8)
+
+ pred_median = prediction_nan.reshape((B, C,-1)).nanmedian(2, keepdims=True)[0].unsqueeze(-1) # [b,c,h,w]
+ pred_median[torch.isnan(pred_median)] = 0
+ pred_diff = (torch.abs(prediction - pred_median) * mask_valid).reshape((B, C, -1))
+ pred_s = pred_diff.sum(dim=2)[:, :, None, None] / valid_pixs
+ pred_trans = (prediction - pred_median) / (pred_s + 1e-8)
+
+ loss = torch.sum(torch.abs(gt_trans - pred_trans)*mask_valid) / (torch.sum(mask_valid) + 1e-8)
+ return pred_trans, gt_trans, loss
+
+ def forward(self, prediction, target, mask=None, **kwargs):
+ """
+ Calculate loss.
+ """
+ B, C, H, W = target.shape
+ hierachy_masks = self.get_hierachy_masks(B, (H, W), mask) # [N, B, 1, H, W]
+ hierachy_masks_shape = hierachy_masks.reshape(-1, C, H, W)
+ prediction_hie = prediction.unsqueeze(0).repeat(hierachy_masks.shape[0], 1, 1, 1, 1).reshape(-1, C, H, W)
+
+ target_hie = target.unsqueeze(0).repeat(hierachy_masks.shape[0], 1, 1, 1, 1).reshape(-1, C, H, W)
+
+ #_, _, loss = self.ssi_mae(prediction, target, mask)
+ _, _, loss = self.ssi_mae(prediction_hie, target_hie, hierachy_masks_shape)
+ return loss * self.loss_weight
+
+if __name__ == '__main__':
+ torch.manual_seed(1)
+ torch.cuda.manual_seed_all(1)
+ ssil = HDSNLoss()
+ pred = torch.rand((2, 1, 256, 256)).cuda()
+ gt = torch.rand((2, 1, 256, 256)).cuda()#torch.zeros_like(pred).cuda() #
+ gt[:, :, 100:256, 0:100] = -1
+ mask = gt > 0
+ out = ssil(pred, gt, mask)
+ print(out)
diff --git a/training/mono/model/losses/HDSNL_random.py b/training/mono/model/losses/HDSNL_random.py
new file mode 100644
index 0000000000000000000000000000000000000000..28dde298f3e3c44a1980cc513a2f8e191d5de2bb
--- /dev/null
+++ b/training/mono/model/losses/HDSNL_random.py
@@ -0,0 +1,230 @@
+import torch
+import torch.nn as nn
+import numpy as np
+#from numba import jit
+
+class HDSNRandomLoss(nn.Module):
+ """
+    Hierarchical depth spatial normalization loss.
+    Replaces the original grid masks with randomly created masks.
+    loss = MAE((d - median(d)) / s - (d' - median(d')) / s'), where s = mean(|d - median(d)|)
+ """
+ def __init__(self, loss_weight=1.0, random_num=32, data_type=['sfm', 'stereo', 'lidar', 'denselidar', 'denselidar_nometric','denselidar_syn'], disable_dataset=['MapillaryPSD'], sky_id=142, batch_limit=8, **kwargs):
+ super(HDSNRandomLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.random_num = random_num
+ self.data_type = data_type
+ self.sky_id = sky_id
+ self.batch_limit = batch_limit
+ self.eps = 1e-6
+ self.disable_dataset = disable_dataset
+
+ def get_random_masks_for_batch(self, image_size: list)-> torch.Tensor:
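+        # Samples `random_num` rectangular crops whose height/width lie between 12.5% and 50% of the
+        # image size, with random top-left corners and crop extents clipped to the image border;
+        # each crop becomes one boolean mask used for local scale-shift-invariant normalization.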
+ height, width = image_size
+ crop_h_min = int(0.125 * height)
+ crop_h_max = int(0.5 * height)
+ crop_w_min = int(0.125 * width)
+ crop_w_max = int(0.5 * width)
+ h_max = height - crop_h_min
+ w_max = width - crop_w_min
+ crop_height = np.random.choice(np.arange(crop_h_min, crop_h_max), self.random_num, replace=False)
+ crop_width = np.random.choice(np.arange(crop_w_min, crop_w_max), self.random_num, replace=False)
+ crop_y = np.random.choice(h_max, self.random_num, replace=False)
+ crop_x = np.random.choice(w_max, self.random_num, replace=False)
+ crop_y_end = crop_height + crop_y
+ crop_y_end[crop_y_end>=height] = height
+ crop_x_end = crop_width + crop_x
+ crop_x_end[crop_x_end>=width] = width
+
+ mask_new = torch.zeros((self.random_num, height, width), dtype=torch.bool, device="cuda") #.cuda() #[N, H, W]
+ for i in range(self.random_num):
+ mask_new[i, crop_y[i]:crop_y_end[i], crop_x[i]:crop_x_end[i]] = True
+
+ return mask_new
+ #return crop_y, crop_y_end, crop_x, crop_x_end
+
+ def reorder_sem_masks(self, sem_label):
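+        # Builds one mask per semantic id (ignoring invalid ids and the sky class), keeps only
+        # segments larger than 500 pixels, and caps the number of masks at `random_num`;
+        # falls back to the plain valid mask when no segment qualifies.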
+ # reorder the semantic mask of a batch
+ assert sem_label.ndim == 3
+ semantic_ids = torch.unique(sem_label[(sem_label>0) & (sem_label != self.sky_id)])
+ sem_masks = [sem_label == id for id in semantic_ids]
+ if len(sem_masks) == 0:
+ # no valid semantic labels
+ out = sem_label > 0
+ return out
+
+ sem_masks = torch.cat(sem_masks, dim=0)
+ mask_batch = torch.sum(sem_masks.reshape(sem_masks.shape[0], -1), dim=1) > 500
+ sem_masks = sem_masks[mask_batch]
+ if sem_masks.shape[0] > self.random_num:
+ balance_samples = np.random.choice(sem_masks.shape[0], self.random_num, replace=False)
+ sem_masks = sem_masks[balance_samples, ...]
+
+ if sem_masks.shape[0] == 0:
+ # no valid semantic labels
+ out = sem_label > 0
+ return out
+
+ if sem_masks.ndim == 2:
+ sem_masks = sem_masks[None, :, :]
+ return sem_masks
+
+ def ssi_mae(self, prediction, target, mask_valid):
+ B, C, H, W = target.shape
+ prediction_nan = prediction.clone().detach()
+ target_nan = target.clone()
+ prediction_nan[~mask_valid] = float('nan')
+ target_nan[~mask_valid] = float('nan')
+
+ valid_pixs = mask_valid.reshape((B, C,-1)).sum(dim=2, keepdims=True) + 1e-10
+ valid_pixs = valid_pixs[:, :, :, None]
+
+ gt_median = target_nan.reshape((B, C,-1)).nanmedian(2, keepdims=True)[0].unsqueeze(-1) # [b,c,h,w]
+ gt_median[torch.isnan(gt_median)] = 0
+ gt_diff = (torch.abs(target - gt_median) ).reshape((B, C, -1))
+ gt_s = gt_diff.sum(dim=2)[:, :, None, None] / valid_pixs
+ gt_trans = (target - gt_median) / (gt_s + self.eps)
+
+ pred_median = prediction_nan.reshape((B, C,-1)).nanmedian(2, keepdims=True)[0].unsqueeze(-1) # [b,c,h,w]
+ pred_median[torch.isnan(pred_median)] = 0
+ pred_diff = (torch.abs(prediction - pred_median)).reshape((B, C, -1))
+ pred_s = pred_diff.sum(dim=2)[:, :, None, None] / valid_pixs
+ pred_trans = (prediction - pred_median) / (pred_s + self.eps)
+
+ loss_sum = torch.sum(torch.abs(gt_trans - pred_trans)*mask_valid)
+ return loss_sum
+
+ def conditional_ssi_mae(self, prediction, target, mask_valid):
+ B, C, H, W = target.shape
+ conditional_rank_ids = np.random.choice(B, B, replace=False)
+
+ prediction_nan = prediction.clone()
+ target_nan = target.clone()
+ prediction_nan[~mask_valid] = float('nan')
+ target_nan[~mask_valid] = float('nan')
+
+ valid_pixs = mask_valid.reshape((B, C,-1)).sum(dim=2, keepdims=True) + self.eps
+ valid_pixs = valid_pixs[:, :, :, None].contiguous()
+
+ gt_median = target_nan.reshape((B, C,-1)).nanmedian(2, keepdims=True)[0].unsqueeze(-1) # [b,c,h,w]
+ gt_median[torch.isnan(gt_median)] = 0
+ gt_diff = (torch.abs(target - gt_median) * mask_valid).reshape((B, C,-1))
+ gt_s = gt_diff.sum(dim=2)[:, :, None, None].contiguous() / valid_pixs
+
+ # in case some batches have no valid pixels
+ gt_s_small_mask = gt_s < (torch.mean(gt_s)*0.1)
+ gt_s[gt_s_small_mask] = torch.mean(gt_s)
+ gt_trans = (target - gt_median[conditional_rank_ids]) / (gt_s[conditional_rank_ids] + self.eps)
+
+ pred_median = prediction_nan.reshape((B, C,-1)).nanmedian(2, keepdims=True)[0].unsqueeze(-1) # [b,c,h,w]
+ pred_median[torch.isnan(pred_median)] = 0
+ pred_diff = (torch.abs(prediction - pred_median) * mask_valid).reshape((B, C,-1))
+ pred_s = pred_diff.sum(dim=2)[:, :, None, None].contiguous() / valid_pixs
+ pred_s[gt_s_small_mask] = torch.mean(pred_s)
+ pred_trans = (prediction - pred_median[conditional_rank_ids]) / (pred_s[conditional_rank_ids] + self.eps)
+
+ loss_sum = torch.sum(torch.abs(gt_trans - pred_trans)*mask_valid)
+ # print(torch.abs(gt_trans - pred_trans)[mask_valid])
+ return loss_sum
+
+
+ def forward(self, prediction, target, mask=None, sem_mask=None, **kwargs):
+ """
+ Calculate loss.
+ """
+ B, C, H, W = target.shape
+
+ loss = 0.0
+ valid_pix = 0.0
+
+ device = target.device
+
+ batches_dataset = kwargs['dataset']
+ self.batch_valid = torch.tensor([1 if batch_dataset not in self.disable_dataset else 0 \
+ for batch_dataset in batches_dataset], device=device)[:,None,None,None]
+
+ batch_limit = self.batch_limit
+
+ random_sample_masks = self.get_random_masks_for_batch((H, W)) # [N, H, W]
+ for i in range(B):
+ # each batch
+ mask_i = mask[i, ...] #[1, H, W]
+ if self.batch_valid[i, ...] < 0.5:
+ loss += 0 * torch.sum(prediction[i, ...])
+ valid_pix += 0 * torch.sum(mask_i)
+ continue
+
+ pred_i = prediction[i, ...].unsqueeze(0).repeat(batch_limit, 1, 1, 1)
+ target_i = target[i, ...].unsqueeze(0).repeat(batch_limit, 1, 1, 1)
+
+ # get semantic masks
+ sem_label_i = sem_mask[i, ...] if sem_mask is not None else None
+ if sem_label_i is not None:
+ sem_masks = self.reorder_sem_masks(sem_label_i) # [N, H, W]
+ random_sem_masks = torch.cat([random_sample_masks, sem_masks], dim=0)
+ else:
+ random_sem_masks = random_sample_masks
+ #random_sem_masks = random_sample_masks
+
+
+ sampled_masks_num = random_sem_masks.shape[0]
+ loops = int(np.ceil(sampled_masks_num / batch_limit))
+ conditional_rank_ids = np.random.choice(sampled_masks_num, sampled_masks_num, replace=False)
+
+ for j in range(loops):
+ mask_random_sem_loopi = random_sem_masks[j*batch_limit:(j+1)*batch_limit, ...]
+ mask_sample = (mask_i & mask_random_sem_loopi).unsqueeze(1) # [N, 1, H, W]
+ loss += self.ssi_mae(
+ prediction=pred_i[:mask_sample.shape[0], ...],
+ target=target_i[:mask_sample.shape[0], ...],
+ mask_valid=mask_sample)
+ valid_pix += torch.sum(mask_sample)
+
+ # conditional ssi loss
+ # rerank_mask_random_sem_loopi = random_sem_masks[conditional_rank_ids, ...][j*batch_limit:(j+1)*batch_limit, ...]
+ # rerank_mask_sample = (mask_i & rerank_mask_random_sem_loopi).unsqueeze(1) # [N, 1, H, W]
+ # loss_cond = self.conditional_ssi_mae(
+ # prediction=pred_i[:rerank_mask_sample.shape[0], ...],
+ # target=target_i[:rerank_mask_sample.shape[0], ...],
+ # mask_valid=rerank_mask_sample)
+ # print(loss_cond / (torch.sum(rerank_mask_sample) + 1e-10), loss_cond, torch.sum(rerank_mask_sample))
+ # loss += loss_cond
+ # valid_pix += torch.sum(rerank_mask_sample)
+
+ # crop_y, crop_y_end, crop_x, crop_x_end = self.get_random_masks_for_batch((H, W)) # [N,]
+ # for j in range(B):
+ # for i in range(self.random_num):
+ # mask_crop = mask[j, :, crop_y[i]:crop_y_end[i], crop_x[i]:crop_x_end[i]][None, ...] #[1, 1, crop_h, crop_w]
+ # target_crop = target[j, :, crop_y[i]:crop_y_end[i], crop_x[i]:crop_x_end[i]][None, ...]
+ # pred_crop = prediction[j, :, crop_y[i]:crop_y_end[i], crop_x[i]:crop_x_end[i]][None, ...]
+ # loss += self.ssi_mae(prediction=pred_crop, target=target_crop, mask_valid=mask_crop)
+ # valid_pix += torch.sum(mask_crop)
+
+ # the whole image
+ mask = mask * self.batch_valid.bool()
+ loss += self.ssi_mae(
+ prediction=prediction,
+ target=target,
+ mask_valid=mask)
+ valid_pix += torch.sum(mask)
+ loss = loss / (valid_pix + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = 0 * torch.sum(prediction)
+ print(f'HDSNL NAN error, {loss}, valid pix: {valid_pix}')
+ return loss * self.loss_weight
+
+if __name__ == '__main__':
+ torch.manual_seed(1)
+ torch.cuda.manual_seed_all(1)
+ ssil = HDSNRandomLoss()
+ pred = torch.rand((8, 1, 256, 512)).cuda()
+ gt = torch.rand((8, 1, 256, 512)).cuda()#torch.zeros_like(pred).cuda() #
+ gt[1:, :, 100:256, 100:350] = -1
+ gt[:2, ...] = -1
+ mask = gt > 0
+ sem_mask = np.random.randint(-1, 200, (8, 1, 256, 512))
+ sem_mask[sem_mask>0] = -1
+ sem_mask_torch = torch.from_numpy(sem_mask).cuda()
+
+ out = ssil(pred, gt, mask, sem_mask_torch)
+ print(out)
diff --git a/training/mono/model/losses/L1.py b/training/mono/model/losses/L1.py
new file mode 100644
index 0000000000000000000000000000000000000000..9646e85f313432153cfd10ff746dd61817347be0
--- /dev/null
+++ b/training/mono/model/losses/L1.py
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+
+class L1Loss(nn.Module):
+ """
+ Compute L1 loss.
+ """
+ def __init__(self, loss_weight=1, data_type=['lidar', 'denselidar', 'stereo', 'denselidar_syn'], **kwargs):
+ super(L1Loss, self).__init__()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+ def forward(self, prediction, target, mask=None, **kwargs):
+ diff = torch.abs(prediction - target)* mask
+ loss = torch.sum(diff) / (torch.sum(mask) + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = 0 * torch.sum(prediction)
+ print(f'L1 NAN error, {loss}')
+ #raise RuntimeError(f'Silog error, {loss}, d_square_mean: {d_square_mean}, d_mean: {d_mean}')
+ return loss * self.loss_weight
+
+class L1DispLoss(nn.Module):
+ """
+    Compute L1 loss of disparity.
+ """
+ def __init__(self, loss_weight=1, data_type=['lidar', 'denselidar', 'stereo', 'denselidar_syn'], **kwargs):
+ super(L1DispLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+ def forward(self, prediction_disp, inv_depth, mask=None, **kwargs):
+ # gt_disp_mask = ~torch.all(inv_depth == 0, dim=1, keepdim=True)
+ # if mask is None:
+ # mask = gt_disp_mask
+ diff = torch.abs(prediction_disp - inv_depth)* mask
+ loss = torch.sum(diff) / (torch.sum(mask) + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = 0 * torch.sum(prediction_disp)
+ #raise RuntimeError(f'Silog error, {loss}, d_square_mean: {d_square_mean}, d_mean: {d_mean}')
+ return loss * self.loss_weight
+
+class L1InverseLoss(nn.Module):
+ """
+    Compute L1 loss in inverse-depth (disparity) space.
+ """
+ def __init__(self, loss_weight=1, data_type=['lidar', 'denselidar', 'stereo'], **kwargs):
+ super(L1InverseLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+ def forward(self, prediction, inv_depth, mask=None, **kwargs):
+ mask = torch.logical_and(mask, inv_depth>0)
+ inv_pred = 1.0 / prediction * 10.0
+ inv_pred[~mask] = -1
+ diff = torch.abs(inv_pred - inv_depth)* mask
+ loss = torch.sum(diff) / (torch.sum(mask) + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = 0 * torch.sum(inv_pred)
+ #raise RuntimeError(f'Silog error, {loss}, d_square_mean: {d_square_mean}, d_mean: {d_mean}')
+ return loss * self.loss_weight
\ No newline at end of file
diff --git a/training/mono/model/losses/NormalBranchLoss.py b/training/mono/model/losses/NormalBranchLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4ecd2ec2ebb0ab10fc4305b80bb2e527b220c6d
--- /dev/null
+++ b/training/mono/model/losses/NormalBranchLoss.py
@@ -0,0 +1,732 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+from .depth_to_normal import Depth2Normal
+
+# compute loss
+class NormalBranchLoss(nn.Module):
+ def __init__(self, loss_weight=1.0, data_type=['sfm', 'stereo', 'denselidar', 'denselidar_syn'], d2n_dataset=['ScanNetAll'], loss_fn='UG_NLL_ours', **kwargs):
+ """loss_fn can be one of following:
+ - L1 - L1 loss (no uncertainty)
+ - L2 - L2 loss (no uncertainty)
+ - AL - Angular loss (no uncertainty)
+ - NLL_vMF - NLL of vonMF distribution
+ - NLL_ours - NLL of Angular vonMF distribution
+ - UG_NLL_vMF - NLL of vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
+ - UG_NLL_ours - NLL of Angular vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
+ - NLL_ours_GRU - NLL of Angular vonMF distribution for GRU sequence
+ """
+ super(NormalBranchLoss, self).__init__()
+ self.loss_type = loss_fn
+ if self.loss_type in ['L1', 'L2', 'AL', 'NLL_vMF', 'NLL_ours']:
+ # self.loss_fn = self.forward_R
+ raise NotImplementedError
+ elif self.loss_type in ['UG_NLL_vMF']:
+ # self.loss_fn = self.forward_UG
+ raise NotImplementedError
+ elif self.loss_type in ['UG_NLL_ours']:
+ self.loss_fn = self.forward_UG
+ elif self.loss_type in ['NLL_ours_GRU', 'NLL_ours_GRU_auxi']:
+ self.loss_type = 'NLL_ours'
+ self.loss_fn = self.forward_GRU
+ self.loss_gamma = 0.9
+ try:
+ self.loss_weight_auxi = kwargs['loss_weight_auxi']
+ except:
+ self.loss_weight_auxi = 0.0
+ else:
+ raise Exception('invalid loss type')
+
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+
+ #self.d2n_dataset = d2n_dataset
+ #self.depth2normal = Depth2Normal()
+
+
+
+ def forward(self, **kwargs):
+ # device = kwargs['mask'].device
+ # B, _, H, W = kwargs['mask'].shape
+ # pad_mask = torch.zeros_like(kwargs['mask'], device=device)
+ # for b in range(B):
+ # pad = kwargs['pad'][b].squeeze()
+ # pad_mask[b, :, pad[0]:H-pad[1], pad[2]:W-pad[3]] = True
+
+ # loss = self.loss_fn(pad_mask=pad_mask, **kwargs)
+ loss = self.loss_fn(**kwargs)
+
+ return loss * self.loss_weight
+
+
+ def forward_GRU(self, normal_out_list, normal, target, mask, intrinsic, pad_mask=None, auxi_normal=None, **kwargs):
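+        # GRU-iteration supervision: each intermediate normal prediction is penalized with an
+        # exponentially decayed weight i_weight = gamma'**(n - 1 - i), where gamma' is adjusted
+        # from loss_gamma so the decay profile stays comparable for any number of iterations.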
+ n_predictions = len(normal_out_list)
+ assert n_predictions >= 1
+ loss = 0.0
+
+ # device = pad_mask.device
+ # batches_dataset = kwargs['dataset']
+ # self.batch_with_d2n = torch.tensor([0 if batch_dataset not in self.d2n_dataset else 1 \
+ # for batch_dataset in batches_dataset], device=device)[:,None,None,None]
+
+ # scale = kwargs['scale'][:, None, None].float()
+ # normal_d2n, new_mask_d2n = self.depth2normal(target, intrinsic, pad_mask, scale)
+
+ gt_normal_mask = ~torch.all(normal == 0, dim=1, keepdim=True) & mask
+
+ if auxi_normal != None:
+ auxi_normal_mask = ~gt_normal_mask
+
+ #normal = normal * (1 - self.batch_with_d2n) + normal_d2n * self.batch_with_d2n
+ # gt_normal_mask = gt_normal_mask * (1 - self.batch_with_d2n) + mask * new_mask_d2n * self.batch_with_d2n
+
+ if gt_normal_mask.sum() < 10:
+ if auxi_normal == None:
+ for norm_out in normal_out_list:
+ loss += norm_out.sum() * 0
+ return loss
+
+ for i, norm_out in enumerate(normal_out_list):
+ # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo iterations
+ adjusted_loss_gamma = self.loss_gamma**(15/(n_predictions - 1))
+ i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
+
+ curr_loss = self.forward_R(norm_out.clone(), normal, gt_normal_mask)
+ if auxi_normal != None:
+ auxi_loss = self.forward_R(norm_out.clone(), auxi_normal[:, :3], auxi_normal_mask)
+ curr_loss = curr_loss + self.loss_weight_auxi * auxi_loss
+
+ if torch.isnan(curr_loss).item() | torch.isinf(curr_loss).item():
+ curr_loss = 0 * torch.sum(norm_out)
+ print(f'NormalBranchLoss forward_GRU NAN error, {curr_loss}')
+
+ loss += curr_loss * i_weight
+
+ return loss
+
+ def forward_R(self, norm_out, gt_norm, gt_norm_mask):
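+        # For the 'NLL_ours' branch the per-pixel loss is the angular vMF negative log-likelihood
+        # -log(kappa^2 + 1) + kappa * acos(<n_pred, n_gt>) + log(1 + exp(-kappa * pi)),
+        # where kappa is the predicted concentration; near-degenerate dot products
+        # (> 0.999 or < -0.999) are masked out before taking acos.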
+ pred_norm, pred_kappa = norm_out[:, 0:3, :, :], norm_out[:, 3:, :, :]
+
+ if self.loss_type == 'L1':
+ l1 = torch.sum(torch.abs(gt_norm - pred_norm), dim=1, keepdim=True)
+ loss = torch.mean(l1[gt_norm_mask])
+
+ elif self.loss_type == 'L2':
+ l2 = torch.sum(torch.square(gt_norm - pred_norm), dim=1, keepdim=True)
+ loss = torch.mean(l2[gt_norm_mask])
+
+ elif self.loss_type == 'AL':
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.0
+
+ al = torch.acos(dot[valid_mask])
+ loss = torch.mean(al)
+
+ elif self.loss_type == 'NLL_vMF':
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.0
+
+ dot = dot[valid_mask]
+ kappa = pred_kappa[:, 0, :, :][valid_mask]
+
+ loss_pixelwise = - torch.log(kappa) \
+ - (kappa * (dot - 1)) \
+ + torch.log(1 - torch.exp(- 2 * kappa))
+ loss = torch.mean(loss_pixelwise)
+
+ elif self.loss_type == 'NLL_ours':
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.5
+
+ dot = dot[valid_mask]
+ kappa = pred_kappa[:, 0, :, :][valid_mask]
+
+ loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
+ + kappa * torch.acos(dot) \
+ + torch.log(1 + torch.exp(-kappa * np.pi))
+ loss = torch.mean(loss_pixelwise)
+
+ else:
+ raise Exception('invalid loss type')
+
+ return loss
+
+
+ def forward_UG(self, normal_pred_list, normal_coord_list, normal, mask, **kwargs):
+ gt_normal_mask = ~torch.all(normal == 0, dim=1, keepdim=True) & mask
+
+ # gt_norm = norms[0]
+ # gt_normal_mask = (gt_norm[:, 0:1, :, :] == 0) & (gt_norm[:, 1:2, :, :] == 0) & (gt_norm[:, 2:3, :, :] == 0)
+ # gt_normal_mask = ~gt_normal_mask
+ loss = 0.0
+
+ if gt_normal_mask[gt_normal_mask].numel() < 10:
+ for (pred, coord) in zip(normal_pred_list, normal_coord_list):
+ if pred is not None:
+ loss += pred.sum() * 0.
+ if coord is not None:
+ loss += coord.sum() * 0.
+ return loss
+
+
+ for (pred, coord) in zip(normal_pred_list, normal_coord_list):
+ if coord is None:
+ pred = F.interpolate(pred, size=[normal.size(2), normal.size(3)], mode='bilinear', align_corners=True)
+ pred_norm, pred_kappa = pred[:, 0:3, :, :], pred[:, 3:, :, :]
+
+ # if self.loss_type == 'UG_NLL_vMF':
+ # dot = torch.cosine_similarity(pred_norm, normal, dim=1)
+
+ # valid_mask = normal_mask[:, 0, :, :].float() \
+ # * (dot.detach() < 0.999).float() \
+ # * (dot.detach() > -0.999).float()
+ # valid_mask = valid_mask > 0.5
+
+ # # mask
+ # dot = dot[valid_mask]
+ # kappa = pred_kappa[:, 0, :, :][valid_mask]
+
+ # loss_pixelwise = - torch.log(kappa) \
+ # - (kappa * (dot - 1)) \
+ # + torch.log(1 - torch.exp(- 2 * kappa))
+ # loss = loss + torch.mean(loss_pixelwise)
+
+ if self.loss_type == 'UG_NLL_ours':
+ dot = torch.cosine_similarity(pred_norm, normal, dim=1)
+
+ valid_mask = gt_normal_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.5
+
+ dot = dot[valid_mask]
+ kappa = pred_kappa[:, 0, :, :][valid_mask]
+
+ loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
+ + kappa * torch.acos(dot) \
+ + torch.log(1 + torch.exp(-kappa * np.pi))
+ loss = loss + torch.mean(loss_pixelwise)
+
+ else:
+ raise Exception
+
+ else:
+ # coord: B, 1, N, 2
+ # pred: B, 4, N
+ gt_norm_ = F.grid_sample(normal, coord, mode='nearest', align_corners=True) # (B, 3, 1, N)
+ gt_norm_mask_ = F.grid_sample(gt_normal_mask.float(), coord, mode='nearest', align_corners=True) # (B, 1, 1, N)
+ gt_norm_ = gt_norm_[:, :, 0, :] # (B, 3, N)
+ gt_norm_mask_ = gt_norm_mask_[:, :, 0, :] > 0.5 # (B, 1, N)
+
+ pred_norm, pred_kappa = pred[:, 0:3, :], pred[:, 3:, :]
+
+ # if self.loss_type == 'UG_NLL_vMF':
+ # dot = torch.cosine_similarity(pred_norm, gt_norm_, dim=1) # (B, N)
+
+ # valid_mask = gt_norm_mask_[:, 0, :].float() \
+ # * (dot.detach() < 0.999).float() \
+ # * (dot.detach() > -0.999).float()
+ # valid_mask = valid_mask > 0.5
+
+ # dot = dot[valid_mask]
+ # kappa = pred_kappa[:, 0, :][valid_mask]
+
+ # loss_pixelwise = - torch.log(kappa) \
+ # - (kappa * (dot - 1)) \
+ # + torch.log(1 - torch.exp(- 2 * kappa))
+ # loss = loss + torch.mean(loss_pixelwise)
+
+ if self.loss_type == 'UG_NLL_ours':
+ dot = torch.cosine_similarity(pred_norm, gt_norm_, dim=1) # (B, N)
+
+ valid_mask = gt_norm_mask_[:, 0, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.5
+
+ dot = dot[valid_mask]
+ kappa = pred_kappa[:, 0, :][valid_mask]
+
+ loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
+ + kappa * torch.acos(dot) \
+ + torch.log(1 + torch.exp(-kappa * np.pi))
+ loss = loss + torch.mean(loss_pixelwise)
+
+ else:
+ raise Exception
+ return loss
+
+
+
+
+# confidence-guided sampling
+@torch.no_grad()
+def sample_points(init_normal, confidence_map, gt_norm_mask, sampling_ratio, beta=1):
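+    # Confidence-guided sampling: keep N = sampling_ratio * H * W pixels per image, taking the
+    # beta*N most confident ones plus a random "coverage" subset of the remainder (beta=1 means
+    # purely confidence-ranked). Pixels outside gt_norm_mask are pushed to the bottom of the
+    # ranking by setting their confidence to -1e4. Returns a boolean [B, 1, H, W] sample mask.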
+ device = init_normal.device
+ B, _, H, W = init_normal.shape
+ N = int(sampling_ratio * H * W)
+ beta = beta
+
+ # confidence map
+ # confidence_map = init_normal[:, 3, :, :] # B, H, W
+
+ # gt_invalid_mask (B, H, W)
+ if gt_norm_mask is not None:
+ gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest')
+ gt_invalid_mask = gt_invalid_mask < 0.5
+ confidence_map[gt_invalid_mask] = -1e4
+
+ # (B, H*W)
+ _, idx = confidence_map.view(B, -1).sort(1, descending=True)
+
+ # confidence sampling
+ if int(beta * N) > 0:
+ importance = idx[:, :int(beta * N)] # B, beta*N
+
+ # remaining
+ remaining = idx[:, int(beta * N):] # B, H*W - beta*N
+
+ # coverage
+ num_coverage = N - int(beta * N)
+
+ if num_coverage <= 0:
+ samples = importance
+ else:
+ coverage_list = []
+ for i in range(B):
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
+ samples = torch.cat((importance, coverage), dim=1) # B, N
+
+ else:
+ # remaining
+ remaining = idx[:, :] # B, H*W
+
+ # coverage
+ num_coverage = N
+
+ coverage_list = []
+ for i in range(B):
+ idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N"
+ coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N
+ coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N
+ samples = coverage
+
+ # point coordinates
+ rows_int = samples // W # 0 for first row, H-1 for last row
+ # rows_float = rows_int / float(H-1) # 0 to 1.0
+ # rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0
+
+ cols_int = samples % W # 0 for first column, W-1 for last column
+ # cols_float = cols_int / float(W-1) # 0 to 1.0
+ # cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0
+
+ # point_coords = torch.zeros(B, 1, N, 2)
+ # point_coords[:, 0, :, 0] = cols_float # x coord
+ # point_coords[:, 0, :, 1] = rows_float # y coord
+ # point_coords = point_coords.to(device)
+ # return point_coords, rows_int, cols_int
+
+ sample_mask = torch.zeros((B,1,H,W), dtype=torch.bool, device=device)
+ for i in range(B):
+ sample_mask[i, :, rows_int[i,:], cols_int[i,:]] = True
+ return sample_mask
+
+# depth-normal consistency loss
+class DeNoConsistencyLoss(nn.Module):
+ def __init__(self, loss_weight=1.0, data_type=['stereo', 'lidar', 'denselidar', 'denselidar_nometric', 'denselidar_syn'], loss_fn='NLL_ours', \
+ sky_id=142, scale=1, norm_dataset=['Taskonomy', 'Matterport3D', 'Replica', 'Hypersim', 'NYU'], no_sky_dataset=['BigData', 'DIODE', 'Completion', 'Matterport3D'], disable_dataset=[], depth_detach=False, **kwargs):
+ """loss_fn can be one of following:
+ - L1 - L1 loss (no uncertainty)
+ - L2 - L2 loss (no uncertainty)
+ - AL - Angular loss (no uncertainty)
+ - NLL_vMF - NLL of vonMF distribution
+ - NLL_ours - NLL of Angular vonMF distribution
+ - UG_NLL_vMF - NLL of vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
+ - UG_NLL_ours - NLL of Angular vonMF distribution (+ pixel-wise MLP + uncertainty-guided sampling)
+ - NLL_ours_GRU - NLL of Angular vonMF distribution for GRU sequence
+ - CEL - cosine embedding loss
+ - CEL_GRU
+ """
+ super(DeNoConsistencyLoss, self).__init__()
+ self.loss_type = loss_fn
+ if self.loss_type in ['L1', 'L2', 'NLL_vMF']:
+ # self.loss_fn = self.forward_R
+ raise NotImplementedError
+ elif self.loss_type in ['UG_NLL_vMF']:
+ # self.loss_fn = self.forward_UG
+ raise NotImplementedError
+ elif self.loss_type in ['UG_NLL_ours']:
+ # self.loss_fn = self.forward_UG
+ raise NotImplementedError
+ elif self.loss_type in ['NLL_ours']:
+ self.loss_fn = self.forward_J # confidence Joint optimization
+ self.loss_gamma = 0.9
+ elif self.loss_type in ['AL', 'CEL', 'CEL_L2']:
+ self.loss_fn = self.forward_S # confidence Sample
+ elif self.loss_type in ['CEL_GRU']:
+ self.loss_fn = self.forward_S_GRU # gru
+ self.loss_gamma = 0.9
+ elif 'Search' in self.loss_type:
+ self.loss_fn = self.forward_S_Search
+ else:
+ raise Exception('invalid loss type')
+
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.sky_id = sky_id
+
+        # For datasets without surface-normal GT, the consistency term is rescaled by nonorm_data_scale; datasets with normal GT keep weight 1.
+ self.nonorm_data_scale = scale
+ self.norm_dataset = norm_dataset
+ self.no_sky_dataset = no_sky_dataset
+ self.disable_dataset = disable_dataset
+
+ self.depth_detach = depth_detach
+ self.depth2normal = Depth2Normal()
+
+ def forward(self, **kwargs):
+ device = kwargs['mask'].device
+
+ batches_dataset = kwargs['dataset']
+ self.batch_with_norm = torch.tensor([self.nonorm_data_scale if batch_dataset not in self.norm_dataset else 1 \
+ for batch_dataset in batches_dataset], device=device)[:,None,None,None]
+
+ self.batch_enabled= torch.tensor([1 if batch_dataset not in self.disable_dataset else 0 \
+ for batch_dataset in batches_dataset], device=device, dtype=torch.bool)[:,None,None,None]
+ self.batch_with_norm = self.batch_with_norm * self.batch_enabled
+
+
+ self.batch_with_norm_sky = torch.tensor([1 if batch_dataset not in self.no_sky_dataset else 0 \
+ for batch_dataset in batches_dataset], device=device, dtype=torch.bool)[:,None,None,None]
+
+ B, _, H, W = kwargs['mask'].shape
+ pad_mask = torch.zeros_like(kwargs['mask'], device=device)
+ for b in range(B):
+ pad = kwargs['pad'][b].squeeze()
+ pad_mask[b, :, pad[0]:H-pad[1], pad[2]:W-pad[3]] = True
+
+ loss = self.loss_fn(pad_mask=pad_mask, **kwargs)
+ return loss * self.loss_weight
+
+
+ def forward_J(self, prediction, confidence, normal_out_list, intrinsic, pad_mask, sem_mask=None, **kwargs):
+ prediction_normal = normal_out_list[-1].clone()
+
+ # get normal from depth-prediction
+ normal, new_mask = self.depth2normal(prediction.detach() if self.depth_detach else prediction, intrinsic, pad_mask)
+ # mask sky
+ sky_mask = sem_mask != self.sky_id
+ new_mask = new_mask & sky_mask
+ # normal = normal * (~sky_mask)
+ # normal[:,1:2,:,:][sky_mask] = 1
+        # confidence-guided sampling: keep pixels with high depth confidence, whose derived normals are more reliable
+ sample_mask_d = sample_points(prediction, confidence, new_mask, sampling_ratio=0.7)
+
+ # all mask
+ normal_mask = ~torch.all(normal == 0, dim=1, keepdim=True) & new_mask & sample_mask_d
+ if normal_mask.sum() < 10:
+ return 0 * prediction_normal.sum()
+
+ loss = self.forward_R(prediction_normal, normal, normal_mask)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = 0 * torch.sum(prediction_normal)
+            print(f'DeNoConsistencyLoss forward_J NAN error, {loss}')
+
+ return loss
+
+ #def forward_S(self, prediction, confidence, normal_out_list, intrinsic, pad_mask, sem_mask=None, **kwargs):
+ def forward_S(self, prediction, confidence, intrinsic, pad_mask, normal_pred=None, sem_mask=None, target=None, is_initial_pair=False, **kwargs):
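+        # Depth-normal consistency: normals derived from the predicted depth (via Depth2Normal) are
+        # compared against the normal-branch prediction on a mask restricted to confident pixels
+        # (dual confidence sampling over depth and normal confidences, plus confidence > 0.5) and,
+        # when GT depth is available, to non-sky or GT-covered regions.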
+
+ if normal_pred is None:
+ prediction_normal = kwargs['normal_out_list'][-1]
+ else:
+ prediction_normal = normal_pred
+
+ # get normal from depth-prediction
+ #try:
+ scale = kwargs['scale'][:, None, None].float()
+ #except:
+ #scale = 1.0
+ normal, new_mask = self.depth2normal(prediction.detach() if self.depth_detach else prediction, intrinsic, pad_mask, scale)
+
+ sky_mask = sem_mask != self.sky_id
+ if target != None:
+ sampling_ratio = 0.7
+ target_mask = (target > 0)
+ if is_initial_pair == False:
+ pass
+ # mask sky
+ else:
+ sky_mask = torch.nn.functional.interpolate(sky_mask.float(), scale_factor=0.25).bool()
+ target_mask = torch.nn.functional.interpolate(target_mask.float(), scale_factor=0.25).bool()
+ new_mask = new_mask & ((sky_mask & self.batch_with_norm_sky) | target_mask)
+ else:
+ new_mask = torch.ones_like(prediction).bool()
+ sampling_ratio = 0.5
+
+ # normal = normal * (~sky_mask)
+ # normal[:,1:2,:,:][sky_mask] = 1
+
+ # dual sampling
+ confidence_normal = prediction_normal[:, 3:, :, :]
+ sample_mask_n = sample_points(prediction_normal, confidence_normal, new_mask, sampling_ratio=sampling_ratio)
+ sample_mask_d = sample_points(prediction, confidence, new_mask, sampling_ratio=sampling_ratio)
+ conf_mask = confidence > 0.5
+
+ # all mask
+ normal_mask = ~torch.all(normal == 0, dim=1, keepdim=True) & new_mask & sample_mask_n & sample_mask_d & conf_mask
+ if normal_mask.sum() < 10:
+ return 0 * prediction_normal.sum()
+
+ loss = self.forward_R(prediction_normal, normal, normal_mask)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = 0 * torch.sum(prediction_normal)
+            print(f'DeNoConsistencyLoss forward_S NAN error, {loss}')
+
+ return loss
+
+ def forward_S_GRU(self, predictions_list, confidence_list, normal_out_list, intrinsic, pad_mask, sem_mask, target, low_resolution_init, **kwargs):
+ n_predictions = len(normal_out_list)
+ assert n_predictions >= 1
+ loss = 0.0
+
+ for i, (norm, conf, depth) in enumerate(zip(normal_out_list, confidence_list, predictions_list)):
+ # We adjust the loss_gamma so it is consistent for any number of RAFT-Stereo iterations
+ adjusted_loss_gamma = self.loss_gamma**(15/(n_predictions - 1))
+ i_weight = adjusted_loss_gamma**(n_predictions - i - 1)
+
+ if i == 0:
+ is_initial_pair = True
+ new_intrinsic = torch.cat((intrinsic[:, :2, :]/4, intrinsic[:, 2:3, :]), dim=1)
+ curr_loss = self.forward_S(low_resolution_init[0], low_resolution_init[1], new_intrinsic, torch.nn.functional.interpolate(pad_mask.float(), scale_factor=0.25).bool(), low_resolution_init[2], sem_mask, target, is_initial_pair, scale=kwargs['scale'])
+ else:
+ is_initial_pair = False
+ curr_loss = self.forward_S(depth, conf, intrinsic, pad_mask, norm, sem_mask, target, is_initial_pair, scale=kwargs['scale'])
+
+ if torch.isnan(curr_loss).item() | torch.isinf(curr_loss).item():
+ curr_loss = 0 * torch.sum(norm)
+                print(f'DeNoConsistencyLoss forward_S_GRU NAN error, {curr_loss}')
+
+ loss += curr_loss * i_weight
+
+ return loss
+
+
+ def forward_R(self, norm_out, gt_norm, gt_norm_mask, pred_kappa=None):
+ pred_norm = norm_out[:, 0:3, :, :]
+ if pred_kappa is None:
+ pred_kappa = norm_out[:, 3:, :, :]
+
+ if self.loss_type == 'L1':
+ l1 = torch.sum(torch.abs(gt_norm - pred_norm), dim=1, keepdim=True)
+ loss = torch.mean(l1[gt_norm_mask])
+
+ elif self.loss_type == 'L2' or self.loss_type == 'CEL_L2':
+ l2 = torch.sum(torch.square(gt_norm - pred_norm), dim=1, keepdim=True)
+ loss = torch.mean(l2[gt_norm_mask])
+
+ elif self.loss_type == 'AL':
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.0
+
+ al = torch.acos(dot * valid_mask)
+ al = al * self.batch_with_norm[:, 0, :, :]
+ loss = torch.mean(al)
+
+ elif self.loss_type == 'CEL' or self.loss_type == 'CEL_GRU':
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.0
+
+ al = 1 - dot * valid_mask
+ al = al * self.batch_with_norm[:, 0, :, :]
+ loss = torch.mean(al)
+
+ elif self.loss_type == 'NLL_vMF':
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.0
+
+ dot = dot[valid_mask]
+ kappa = pred_kappa[:, 0, :, :][valid_mask]
+
+ loss_pixelwise = - torch.log(kappa) \
+ - (kappa * (dot - 1)) \
+ + torch.log(1 - torch.exp(- 2 * kappa))
+ loss = torch.mean(loss_pixelwise)
+
+ elif self.loss_type == 'NLL_ours':
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.5
+
+ dot = dot * valid_mask
+ kappa = pred_kappa[:, 0, :, :] * valid_mask
+
+ loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
+ + kappa * torch.acos(dot) \
+ + torch.log(1 + torch.exp(-kappa * np.pi))
+ loss_pixelwise = loss_pixelwise * self.batch_with_norm[:, 0, :, :]
+ loss = torch.mean(loss_pixelwise)
+
+ else:
+ raise Exception('invalid loss type')
+
+ return loss
+
+ def forward_S_Search(self, prediction, confidence, intrinsic, pad_mask, normal_pred=None, sem_mask=None, target=None, is_initial_pair=False, **kwargs):
+
+ if normal_pred is None:
+ prediction_normal = kwargs['normal_out_list'][-1]
+ else:
+ prediction_normal = normal_pred
+
+ # get normal from depth-prediction
+ scale = kwargs['scale'][:, None, None].float()
+ candidate_scale = kwargs['candidate_scale'][:, None, None, None].float()
+ normal, new_mask = self.depth2normal(prediction.detach() if self.depth_detach else prediction, intrinsic, pad_mask, scale)
+
+ sky_mask = sem_mask != self.sky_id
+ if target != None:
+ sampling_ratio = 0.7
+ target_mask = (target > 0)
+ if is_initial_pair == False:
+ pass
+ # mask sky
+ else:
+ sky_mask = torch.nn.functional.interpolate(sky_mask.float(), scale_factor=0.25).bool()
+ target_mask = torch.nn.functional.interpolate(target_mask.float(), scale_factor=0.25).bool()
+ new_mask = new_mask & ((sky_mask & self.batch_with_norm_sky) | target_mask)
+ else:
+ new_mask = torch.ones_like(prediction).bool()
+ sampling_ratio = 0.5
+
+ # normal = normal * (~sky_mask)
+ # normal[:,1:2,:,:][sky_mask] = 1
+
+ # dual sampling
+ confidence_normal = prediction_normal[:, 3:, :, :]
+ sample_mask_n = sample_points(prediction_normal, confidence_normal, new_mask, sampling_ratio=sampling_ratio)
+ sample_mask_d = sample_points(prediction, confidence, new_mask, sampling_ratio=sampling_ratio)
+ conf_mask = confidence > 0.5
+
+ # all mask
+ normal_mask = ~torch.all(normal == 0, dim=1, keepdim=True) & new_mask & sample_mask_n & sample_mask_d & conf_mask
+ if normal_mask.sum() < 10:
+ return 0 * prediction_normal.sum()
+
+ prediction_normal = torch.cat((prediction_normal[:,:2]*torch.ones_like(candidate_scale), prediction_normal[:,2:3]*candidate_scale, prediction_normal[:,3:4]*torch.ones_like(candidate_scale)), dim=1)
+
+ norm_x = prediction_normal[:,0:1]
+ norm_y = prediction_normal[:,1:2]
+ norm_z = prediction_normal[:,2:3]
+
+ prediction_normal[:,:3] = prediction_normal[:,:3] / (torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10)
+
+ loss = self.forward_R_Search(prediction_normal, normal, normal_mask)
+ #if torch.isnan(loss).item() | torch.isinf(loss).item():
+ #loss = 0 * torch.sum(prediction_normal)
+ #print(f'NormalBranchLoss forward_GRU NAN error, {loss}')
+
+ return loss
+
+
+ def forward_R_Search(self, norm_out, gt_norm, gt_norm_mask, pred_kappa=None):
+ pred_norm = norm_out[:, 0:3, :, :]
+ if pred_kappa is None:
+ pred_kappa = norm_out[:, 3:, :, :]
+
+ if 'L1' in self.loss_type:
+ l1 = torch.sum(torch.abs(gt_norm - pred_norm), dim=1, keepdim=True)
+ loss = torch.mean(l1*gt_norm_mask, dim=[1, 2, 3])
+
+ elif 'L2' in self.loss_type:
+ l2 = torch.sum(torch.square(gt_norm - pred_norm), dim=1, keepdim=True)
+ loss = torch.mean(l2*gt_norm_mask, dim=[1, 2, 3])
+
+ elif 'AL' in self.loss_type:
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.0
+
+ al = torch.acos(dot * valid_mask)
+ loss = torch.mean(al, dim=[1, 2])
+
+ elif 'CEL' in self.loss_type:
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.0
+
+ al = 1 - dot * valid_mask
+ loss = torch.mean(al, dim=[1, 2])
+
+ elif 'NLL_vMF' in self.loss_type:
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.0
+
+ dot = dot[valid_mask]
+ kappa = pred_kappa[:, 0, :, :][valid_mask]
+
+ loss_pixelwise = - torch.log(kappa) \
+ - (kappa * (dot - 1)) \
+ + torch.log(1 - torch.exp(- 2 * kappa))
+ loss = torch.mean(loss_pixelwise, dim=[1, 2])
+
+ elif 'NLL_ours' in self.loss_type:
+ dot = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+
+ valid_mask = gt_norm_mask[:, 0, :, :].float() \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+ valid_mask = valid_mask > 0.5
+
+ dot = dot * valid_mask
+ kappa = pred_kappa[:, 0, :, :] * valid_mask
+
+ loss_pixelwise = - torch.log(torch.square(kappa) + 1) \
+ + kappa * torch.acos(dot) \
+ + torch.log(1 + torch.exp(-kappa * np.pi))
+ loss = torch.mean(loss_pixelwise, dim=[1, 2])
+
+ else:
+ raise Exception('invalid loss type')
+
+ return loss
\ No newline at end of file
diff --git a/training/mono/model/losses/NormalRegression.py b/training/mono/model/losses/NormalRegression.py
new file mode 100644
index 0000000000000000000000000000000000000000..00b7169de6f6c5753224d0cdd6ad57d5397505a6
--- /dev/null
+++ b/training/mono/model/losses/NormalRegression.py
@@ -0,0 +1,418 @@
+import torch
+from torch import nn
+import numpy as np
+import torch.nn.functional as F
+from .depth_to_normal import Depth2Normal
+"""
+Sampling strategies: RS (Random Sampling), EGS (Edge-Guided Sampling), and IGS (Instance-Guided Sampling)
+"""
+###########
+# RANDOM SAMPLING
+# input:
+# inputs[i,:], targets[i, :], masks[i, :], self.mask_value, self.point_pairs
+# return:
+# inputs_A, inputs_B, targets_A, targets_B, consistent_masks_A, consistent_masks_B
+###########
+def randomSamplingNormal(inputs, targets, masks, sample_num):
+
+ # find A-B point pairs from prediction
+ num_effect_pixels = torch.sum(masks)
+ shuffle_effect_pixels = torch.randperm(num_effect_pixels, device="cuda")
+ valid_inputs = inputs[:, masks]
+ valid_targes = targets[:, masks]
+ inputs_A = valid_inputs[:, shuffle_effect_pixels[0 : sample_num * 2 : 2]]
+ inputs_B = valid_inputs[:, shuffle_effect_pixels[1 : sample_num * 2 : 2]]
+ # find corresponding pairs from GT
+ targets_A = valid_targes[:, shuffle_effect_pixels[0 : sample_num * 2 : 2]]
+ targets_B = valid_targes[:, shuffle_effect_pixels[1 : sample_num * 2 : 2]]
+ if inputs_A.shape[1] != inputs_B.shape[1]:
+ num_min = min(targets_A.shape[1], targets_B.shape[1])
+ inputs_A = inputs_A[:, :num_min]
+ inputs_B = inputs_B[:, :num_min]
+ targets_A = targets_A[:, :num_min]
+ targets_B = targets_B[:, :num_min]
+ return inputs_A, inputs_B, targets_A, targets_B
+
+
+###########
+# EDGE-GUIDED SAMPLING
+# input:
+# inputs[i,:], targets[i, :], masks[i, :], edges_img[i], thetas_img[i], masks[i, :], h, w
+# return:
+# inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B
+###########
+def ind2sub(idx, cols):
+ r = torch.div(idx, cols, rounding_mode='floor')
+ c = idx - r * cols
+ return r, c
+
+
+def sub2ind(r, c, cols):
+ idx = r * cols + c
+ return idx
+
+
+def edgeGuidedSampling(inputs, targets, edges_img, thetas_img, masks, h, w):
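+    # Edge-guided sampling: edge pixels (edge response >= 10% of the maximum) act as anchors; for
+    # each anchor, four points a, b, c, d are placed on both sides of the edge along the local
+    # gradient direction at random offsets of 3-19 px, and the point pairs (a,b), (b,c), (c,d)
+    # are used for the pair-wise normal regression terms.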
+ # find edges
+ edges_max = edges_img.max()
+ edges_min = edges_img.min()
+ edges_mask = edges_img.ge(edges_max * 0.1)
+ edges_loc = edges_mask.nonzero(as_tuple=False)
+
+ thetas_edge = torch.masked_select(thetas_img, edges_mask)
+ minlen = thetas_edge.size()[0]
+
+ # find anchor points (i.e, edge points)
+ sample_num = minlen
+ index_anchors = torch.randint(0, minlen, (sample_num,), dtype=torch.long, device="cuda")
+ theta_anchors = torch.gather(thetas_edge, 0, index_anchors)
+ row_anchors, col_anchors = ind2sub(edges_loc[index_anchors].squeeze(1), w)
+    ## compute the coordinates of the 4 points; offsets along the gradient direction are sampled from [3, 20)
+ distance_matrix = torch.randint(3, 20, (4, sample_num), device="cuda")
+ pos_or_neg = torch.ones(4, sample_num, device="cuda")
+ pos_or_neg[:2, :] = -pos_or_neg[:2, :]
+ distance_matrix = distance_matrix.float() * pos_or_neg
+ col = (
+ col_anchors.unsqueeze(0).expand(4, sample_num).long()
+ + torch.round(
+ distance_matrix.float() * torch.abs(torch.cos(theta_anchors)).unsqueeze(0)
+ ).long()
+ )
+ row = (
+ row_anchors.unsqueeze(0).expand(4, sample_num).long()
+ + torch.round(
+ distance_matrix.float() * torch.abs(torch.sin(theta_anchors)).unsqueeze(0)
+ ).long()
+ )
+
+ # constrain 0= w - 1] = w - 1
+ row[row < 0] = 0
+ row[row > h - 1] = h - 1
+
+ # a-b, b-c, c-d
+ a = sub2ind(row[0, :], col[0, :], w)
+ b = sub2ind(row[1, :], col[1, :], w)
+ c = sub2ind(row[2, :], col[2, :], w)
+ d = sub2ind(row[3, :], col[3, :], w)
+ A = torch.cat((a, b, c), 0)
+ B = torch.cat((b, c, d), 0)
+
+
+
+ inputs_A = inputs[:, A]
+ inputs_B = inputs[:, B]
+ targets_A = targets[:, A]
+ targets_B = targets[:, B]
+ masks_A = torch.gather(masks, 0, A.long())
+ masks_B = torch.gather(masks, 0, B.long())
+
+ # create A, B, C, D mask for visualization
+ # vis_mask = masks.reshape(h, w).cpu().numpy()
+ # vis_row = row.cpu()
+ # vis_col = col.cpu()
+ # visual_A = np.zeros((h, w)).astype(np.bool)
+ # visual_B = np.zeros_like(visual_A)
+ # visual_C = np.zeros_like(visual_A)
+ # visual_D = np.zeros_like(visual_A)
+ # visual_A[vis_row[0, :], vis_col[0, :]] = True
+ # visual_B[vis_row[1, :], vis_col[1, :]] = True
+ # visual_C[vis_row[2, :], vis_col[2, :]] = True
+ # visual_D[vis_row[3, :], vis_col[3, :]] = True
+ # visual_ABCD = [visual_A & vis_mask, visual_B & vis_mask,
+ # visual_C& vis_mask, visual_D& vis_mask]
+ return (
+ inputs_A,
+ inputs_B,
+ targets_A,
+ targets_B,
+ masks_A,
+ masks_B,
+ sample_num,
+ row,
+ col,
+ )
+
+
+######################################################
+# EdgeguidedNormalRankingLoss
+#####################################################
+class EdgeguidedNormalLoss(nn.Module):
+ def __init__(
+ self,
+ point_pairs=10000,
+ cos_theta1=0.25,
+ cos_theta2=0.98,
+ cos_theta3=0.5,
+ cos_theta4=0.86,
+ mask_value=1e-8,
+ loss_weight=1.0,
+ data_type=['stereo', 'denselidar', 'denselidar_nometric','denselidar_syn'],
+ **kwargs
+ ):
+ super(EdgeguidedNormalLoss, self).__init__()
+ self.point_pairs = point_pairs # number of point pairs
+ self.mask_value = mask_value
+ self.cos_theta1 = cos_theta1 # 75 degree
+ self.cos_theta2 = cos_theta2 # 10 degree
+ self.cos_theta3 = cos_theta3 # 60 degree
+ self.cos_theta4 = cos_theta4 # 30 degree
+ # self.kernel = torch.tensor(
+ # np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]], dtype=np.float32),
+ # requires_grad=False,
+ # )[None, None, :, :].cuda()
+ self.depth2normal = Depth2Normal()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+
+ def getEdge(self, images):
+ n, c, h, w = images.size()
+ a = (
+ torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device="cuda")
+ .contiguous()
+ .view((1, 1, 3, 3))
+ .repeat(1, 1, 1, 1)
+ )
+ b = (
+ torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32, device="cuda")
+ .contiguous()
+ .view((1, 1, 3, 3))
+ .repeat(1, 1, 1, 1)
+ )
+ if c == 3:
+ gradient_x = F.conv2d(images[:, 0, :, :].unsqueeze(1), a)
+ gradient_y = F.conv2d(images[:, 0, :, :].unsqueeze(1), b)
+ else:
+ gradient_x = F.conv2d(images, a)
+ gradient_y = F.conv2d(images, b)
+ edges = torch.sqrt(torch.pow(gradient_x, 2) + torch.pow(gradient_y, 2))
+ edges = F.pad(edges, (1, 1, 1, 1), "constant", 0)
+ thetas = torch.atan2(gradient_y, gradient_x)
+ thetas = F.pad(thetas, (1, 1, 1, 1), "constant", 0)
+ return edges, thetas
+
+ def getNormalEdge(self, normals):
+ n, c, h, w = normals.size()
+ a = (
+            torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device="cuda")
+ .contiguous()
+ .view((1, 1, 3, 3))
+ .repeat(3, 1, 1, 1)
+ )
+ b = (
+            torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32, device="cuda")
+ .contiguous()
+ .view((1, 1, 3, 3))
+ .repeat(3, 1, 1, 1)
+ )
+ gradient_x = torch.abs(F.conv2d(normals, a, groups=c))
+ gradient_y = torch.abs(F.conv2d(normals, b, groups=c))
+ gradient_x = gradient_x.mean(dim=1, keepdim=True)
+ gradient_y = gradient_y.mean(dim=1, keepdim=True)
+ edges = torch.sqrt(torch.pow(gradient_x, 2) + torch.pow(gradient_y, 2))
+ edges = F.pad(edges, (1, 1, 1, 1), "constant", 0)
+ thetas = torch.atan2(gradient_y, gradient_x)
+ thetas = F.pad(thetas, (1, 1, 1, 1), "constant", 0)
+ return edges, thetas
+
+ def visual_check(self, rgb, samples):
+ import os
+ import matplotlib.pyplot as plt
+ rgb = rgb.cpu().squeeze().numpy()
+
+ mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis]
+ std= np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis]
+
+ rgb = ((rgb * std) + mean).astype(np.uint8).transpose((1, 2, 0))
+ mask_A, mask_B, mask_C, mask_D = samples
+        rgb[mask_A.astype(bool)] = [255, 0, 0]
+        rgb[mask_B.astype(bool)] = [0, 255, 0]
+        rgb[mask_C.astype(bool)] = [0, 0, 255]
+        rgb[mask_D.astype(bool)] = [255, 255, 0]
+
+ filename = str(np.random.randint(10000))
+ save_path = os.path.join('test_ranking', filename + '.png')
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ plt.imsave(save_path, rgb)
+
+ def forward(self, prediction, target, mask, input, intrinsic, **kwargs):
+ loss = self.get_loss(prediction, target, mask, input, intrinsic, **kwargs)
+ return loss
+
+ def get_loss(self, prediction, target, mask, input, intrinsic, **kwargs):
+ """
+        prediction and target: depth maps (converted to surface normals via Depth2Normal when precomputed normals are not provided in kwargs)
+        input: rgb images
+ """
+ gt_depths = target
+
+ if 'predictions_normals' not in kwargs:
+ predictions_normals, _ = self.depth2normal(prediction, intrinsic, mask)
+ targets_normals, targets_normals_masks = self.depth2normal(target, intrinsic, mask)
+ else:
+ predictions_normals = kwargs['predictions_normals']
+ targets_normals = kwargs['targets_normals']
+ targets_normals_masks = kwargs['targets_normals_masks']
+ masks_normals = mask & targets_normals_masks
+
+ # find edges from RGB
+ edges_img, thetas_img = self.getEdge(input)
+
+ # find edges from normals
+ # edges_normal, thetas_normal = self.getNormalEdge(targets_normals)
+ #mask_img_border = torch.ones_like(edges_normal) # normals on the borders
+ #mask_img_border[:, :, 5:-5, 5:-5] = 0
+ # edges_normal[~targets_normals_masks] = 0
+
+ # find edges from depth
+ edges_depth, thetas_depth = self.getEdge(gt_depths)
+ # edges_depth_mask = edges_depth.ge(edges_depth.max() * 0.1)
+ # edges_mask_dilate = torch.clamp(
+ # torch.nn.functional.conv2d(
+ # edges_depth_mask.float(), self.kernel, padding=(1, 1)
+ # ),
+ # 0,
+ # 1,
+ # ).bool()
+ # edges_normal[edges_mask_dilate] = 0
+ # edges_img[edges_mask_dilate] = 0
+
+ # =============================
+ n, c, h, w = targets_normals.size()
+
+ predictions_normals = predictions_normals.contiguous().view(n, c, -1)
+ targets_normals = targets_normals.contiguous().view(n, c, -1)
+ masks_normals = masks_normals.contiguous().view(n, -1)
+ edges_img = edges_img.contiguous().view(n, -1)
+ thetas_img = thetas_img.contiguous().view(n, -1)
+ # edges_normal = edges_normal.view(n, -1)
+ # thetas_normal = thetas_normal.view(n, -1)
+ edges_depth = edges_depth.contiguous().view(n, -1)
+ thetas_depth = thetas_depth.contiguous().view(n, -1)
+
+ # # initialization
+ losses = 0.0
+ valid_samples = 0.0
+ for i in range(n):
+ # Edge-Guided sampling
+ (
+ inputs_A,
+ inputs_B,
+ targets_A,
+ targets_B,
+ masks_A,
+ masks_B,
+ sample_num,
+ row_img,
+ col_img,
+ ) = edgeGuidedSampling(
+ predictions_normals[i, :],
+ targets_normals[i, :],
+ edges_img[i],
+ thetas_img[i],
+ masks_normals[i, :],
+ h,
+ w,
+ )
+ # Depth-Guided sampling
+ # (
+ # depth_inputs_A,
+ # depth_inputs_B,
+ # depth_targets_A,
+ # depth_targets_B,
+ # depth_masks_A,
+ # depth_masks_B,
+ # depth_sample_num,
+ # row_img,
+ # col_img,
+ # ) = edgeGuidedSampling(
+ # predictions_normals[i, :],
+ # targets_normals[i, :],
+ # edges_depth[i],
+ # thetas_depth[i],
+ # masks_normals[i, :],
+ # h,
+ # w,
+ # )
+ # Normal-Guided sampling
+ # (
+ # normal_inputs_A,
+ # normal_inputs_B,
+ # normal_targets_A,
+ # normal_targets_B,
+ # normal_masks_A,
+ # normal_masks_B,
+ # normal_sample_num,
+ # row_normal,
+ # col_normal,
+ # ) = edgeGuidedSampling(
+ # predictions_normals[i, :],
+ # targets_normals[i, :],
+ # edges_normal[i],
+ # thetas_normal[i],
+ # masks_normals[i, :],
+ # h,
+ # w,
+ # )
+
+ # Combine EGS + DEGS
+ # inputs_A = torch.cat((inputs_A, depth_inputs_A), 1) #normal_inputs_A
+ # inputs_B = torch.cat((inputs_B, depth_inputs_B), 1) # normal_inputs_B
+ # targets_A = torch.cat((targets_A, depth_targets_A), 1) #normal_targets_A
+ # targets_B = torch.cat((targets_B, depth_targets_B), 1) #normal_targets_B
+ # masks_A = torch.cat((masks_A, depth_masks_A), 0) #normal_masks_A
+ # masks_B = torch.cat((masks_B, depth_masks_B), 0) #normal_masks_B
+
+            # consider forward-backward consistency checking, i.e., only compute losses of point pairs with valid GT
+ consistency_mask = masks_A & masks_B
+
+ # GT ordinal relationship
+ target_cos = torch.sum(targets_A * targets_B, dim=0)
+ input_cos = torch.sum(inputs_A * inputs_B, dim=0)
+
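+            # The loss drives the cosine similarity between the predicted normals of each
+            # sampled pair towards 1 (pairs straddling an RGB edge share a normal);
+            # target_cos is computed but not used in this term.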
+ losses += torch.sum(torch.abs(torch.ones_like(target_cos)-input_cos) * consistency_mask.float())
+ valid_samples += torch.sum(consistency_mask.float())
+
+ loss = (losses / (valid_samples + self.eps)) * self.loss_weight
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = 0 * torch.sum(prediction)
+ print(f'Pair-wise Normal Regression Loss NAN error, {loss}, valid pix: {valid_samples}')
+ return loss
+
+def tmp_check_normal(normals, masks, depth):
+ import matplotlib.pyplot as plt
+ import os
+ import cv2
+ from mono.utils.visualization import vis_surface_normal
+ vis_normal1 = vis_surface_normal(normals[0, ...].permute(1, 2, 0).detach(), masks[0,...].detach().squeeze())
+ vis_normal2 = vis_surface_normal(normals[1, ...].permute(1, 2, 0).detach(), masks[1,...].detach().squeeze())
+ vis_depth1 = depth[0, ...].detach().cpu().squeeze().numpy()
+ vis_depth2 = depth[1, ...].detach().cpu().squeeze().numpy()
+
+ name = np.random.randint(100000)
+ os.makedirs('test_normal', exist_ok=True)
+ cv2.imwrite(f'test_normal/{name}.png', vis_normal1)
+ cv2.imwrite(f'test_normal/{name + 1}.png', vis_normal2)
+ plt.imsave(f'test_normal/{name}_d.png', vis_depth1)
+ plt.imsave(f'test_normal/{name + 1}_d.png', vis_depth2)
+
+if __name__ == '__main__':
+ ENL = EdgeguidedNormalLoss()
+ depth = np.random.randn(2, 1, 20, 22)
+ intrin = np.array([[300, 0, 10], [0, 300, 10], [0,0,1]])
+ prediction = np.random.randn(2, 1, 20, 22)
+ imgs = np.random.randn(2, 3, 20, 22)
+ intrinsics = np.stack([intrin, intrin], axis=0)
+
+ depth_t = torch.from_numpy(depth).cuda().float()
+ prediction = torch.from_numpy(prediction).cuda().float()
+ intrinsics = torch.from_numpy(intrinsics).cuda().float()
+ imgs = torch.from_numpy(imgs).cuda().float()
+ depth_t = -1 * torch.abs(depth_t)
+
+    loss = ENL(prediction, depth_t, mask=depth_t>0, input=imgs, intrinsic=intrinsics)
+ print(loss)
\ No newline at end of file
diff --git a/training/mono/model/losses/PWN_Planes.py b/training/mono/model/losses/PWN_Planes.py
new file mode 100644
index 0000000000000000000000000000000000000000..2151f677d0fb0a5a920c13a5b46eda4a0f768f92
--- /dev/null
+++ b/training/mono/model/losses/PWN_Planes.py
@@ -0,0 +1,291 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class PWNPlanesLoss(nn.Module):
+ """
+    Pair-wise normal (PWN) loss over plane instances: points sampled from the same instance plane are constrained to be co-planar.
+ """
+ def __init__(self, delta_cos=0.867, delta_diff_x=0.007,
+ delta_diff_y=0.007, sample_groups=5000, loss_weight=1.0, data_type=['lidar', 'denselidar'], **kwargs):
+ """
+ Virtual normal planes loss, which constrain points to be on the same 3D plane.
+ :para focal_x: folcal length fx
+ :para focal_y: folcal length fy
+ :para input_size: input image size
+ :para delta_cos: a threshold for the angle among three point, three points should not be on the same plane
+ :para delta_diff_x: a threshold for the distance among three points along the x axis
+ :para delta_diff_y: a threshold for the distance among three points along the y axis
+ :para sample_groups: sample groups number, each group with 3 points can construct a plane
+ """
+ super(PWNPlanesLoss, self).__init__()
+ self.delta_cos = delta_cos
+ self.delta_diff_x = delta_diff_x
+ self.delta_diff_y = delta_diff_y
+ self.sample_groups = sample_groups
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+
+ def init_image_coor(self, B, H, W):
+ u = torch.arange(0, H, dtype=torch.float32, device="cuda").contiguous().view(1, H, 1).expand(1, H, W) # [1, H, W]
+ v = torch.arange(0, W, dtype=torch.float32, device="cuda").contiguous().view(1, 1, W).expand(1, H, W) # [1, H, W]
+ ones = torch.ones((1, H, W), dtype=torch.float32, device="cuda")
+ pixel_coords = torch.stack((u, v, ones), dim=1).expand(B, 3, H, W) # [B, 3, H, W]
+ # self.register_buffer('uv', pixel_coords, persistent=False)
+ self.uv = pixel_coords
+
+ def upproj_pcd(self, depth, intrinsics_inv):
+ """Transform coordinates in the pixel frame to the camera frame.
+ Args:
+ depth: depth maps -- [B, 1, H, W]
+ intrinsics_inv: intrinsics_inv matrix for each element of batch -- [B, 3, 3]
+ Returns:
+ array of (u,v,1) cam coordinates -- [B, 3, H, W]
+ """
+ b, _, h, w = depth.size()
+ assert self.uv.shape[0] == b
+ current_pixel_coords = self.uv.reshape(b, 3, -1) # [B, 3, H*W]
+ cam_coords = (intrinsics_inv @ current_pixel_coords)
+ cam_coords = cam_coords.reshape(b, 3, h, w)
+ out = depth * cam_coords
+ return out
+
+ # def transfer_xyz(self, depth):
+ # x = self.u_u0 * torch.abs(depth) / self.focal_length
+ # y = self.v_v0 * torch.abs(depth) / self.focal_length
+ # z = depth
+ # pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1).contiguous() # [b, h, w, c]
+ # return pw
+
+ # def transfer_uvz(self, depth):
+ # max_uv = self.u_u0.max()
+ # u = self.u_u0.repeat((depth.shape[0], 1, 1, 1)) / max_uv
+ # v = self.v_v0.repeat((depth.shape[0], 1, 1, 1)) / max_uv
+ # z = depth
+ # pw = torch.cat([u, v, z], 1).permute(0, 2, 3, 1).contiguous() # [b, h, w, c]
+ # return pw
+
+ def select_index(self, mask_kp):
+ x, _, h, w = mask_kp.shape
+
+ select_size = int(3 * self.sample_groups)
+ p1_x = []
+ p1_y = []
+ p2_x = []
+ p2_y = []
+ p3_x = []
+ p3_y = []
+ valid_batch = torch.ones((x, 1), dtype=torch.bool, device="cuda")
+ for i in range(x):
+ mask_kp_i = mask_kp[i, 0, :, :]
+ valid_points = torch.nonzero(mask_kp_i)
+
+ if valid_points.shape[0] < select_size * 0.6:
+ valid_points = torch.nonzero(~mask_kp_i.to(torch.uint8))
+ valid_batch[i, :] = False
+ elif valid_points.shape[0] < select_size:
+ repeat_idx = torch.randperm(valid_points.shape[0], device="cuda")[:select_size - valid_points.shape[0]]
+ valid_repeat = valid_points[repeat_idx]
+ valid_points = torch.cat((valid_points, valid_repeat), 0)
+ else:
+ valid_points = valid_points
+ """
+
+ if valid_points.shape[0] <= select_size:
+ valid_points = torch.nonzero(~mask_kp_i.to(torch.uint8))
+ valid_batch[i, :] = False
+ """
+ select_indx = torch.randperm(valid_points.size(0), device="cuda")
+
+ p1 = valid_points[select_indx[0:select_size:3]]
+ p2 = valid_points[select_indx[1:select_size:3]]
+ p3 = valid_points[select_indx[2:select_size:3]]
+
+ p1_x.append(p1[:, 1])
+ p1_y.append(p1[:, 0])
+
+ p2_x.append(p2[:, 1])
+ p2_y.append(p2[:, 0])
+
+ p3_x.append(p3[:, 1])
+ p3_y.append(p3[:, 0])
+ p123 = {'p1_x': torch.stack(p1_x), 'p1_y': torch.stack(p1_y),
+ 'p2_x': torch.stack(p2_x), 'p2_y': torch.stack(p2_y),
+ 'p3_x': torch.stack(p3_x), 'p3_y': torch.stack(p3_y),
+ 'valid_batch': valid_batch}
+ return p123
+
+ def form_pw_groups(self, p123, pw):
+ """
+        Form 3D point groups, with 3 points in each group.
+ :param p123: points index
+ :param pw: 3D points, # [1, h, w, c]
+ :return:
+ """
+ p1_x = p123['p1_x']
+ p1_y = p123['p1_y']
+ p2_x = p123['p2_x']
+ p2_y = p123['p2_y']
+ p3_x = p123['p3_x']
+ p3_y = p123['p3_y']
+ batch_list = torch.arange(0, p1_x.shape[0], device="cuda")[:, None]
+ pw = pw.repeat((p1_x.shape[0], 1, 1, 1))
+ pw1 = pw[batch_list, p1_y, p1_x, :]
+ pw2 = pw[batch_list, p2_y, p2_x, :]
+ pw3 = pw[batch_list, p3_y, p3_x, :]
+
+ # [B, N, 3(x,y,z), 3(p1,p2,p3)]
+ pw_groups = torch.cat([pw1[:, :, :, None], pw2[:, :, :, None], pw3[:, :, :, None]], 3)
+ return pw_groups
+
+ def filter_mask(self, pw_pred):
+ """
+ :param pw_pred: constructed 3d vector (x, y, disp), [B, N, 3(x,y,z), 3(p1,p2,p3)]
+ """
+ xy12 = pw_pred[:, :, 0:2, 1] - pw_pred[:, :, 0:2, 0]
+ xy13 = pw_pred[:, :, 0:2, 2] - pw_pred[:, :, 0:2, 0]
+ xy23 = pw_pred[:, :, 0:2, 2] - pw_pred[:, :, 0:2, 1]
+ # Ignore linear
+ xy_diff = torch.cat([xy12[:, :, :, np.newaxis], xy13[:, :, :, np.newaxis], xy23[:, :, :, np.newaxis]],
+ 3) # [b, n, 2(xy), 3]
+ m_batchsize, groups, coords, index = xy_diff.shape
+ proj_query = xy_diff.contiguous().view(m_batchsize * groups, -1, index).permute(0, 2, 1).contiguous() # [bn, 3(p123), 2(xy)]
+ proj_key = xy_diff.contiguous().view(m_batchsize * groups, -1, index) # [bn, 2(xy), 3(p123)]
+ q_norm = proj_query.norm(2, dim=2) # [bn, 3(p123)]
+ nm = torch.bmm(q_norm.contiguous().view(m_batchsize * groups, index, 1), q_norm.contiguous().view(m_batchsize * groups, 1, index)) # []
+ energy = torch.bmm(proj_query, proj_key) # transpose check [bn, 3(p123), 3(p123)]
+ norm_energy = energy / (nm + 1e-8)
+ norm_energy = norm_energy.contiguous().view(m_batchsize * groups, -1) # [bn, 9(p123)]
+        mask_cos = torch.sum((norm_energy > self.delta_cos) + (norm_energy < -self.delta_cos), 1) > 3 # ignore nearly collinear triplets
+        mask_cos = mask_cos.contiguous().view(m_batchsize, groups) # [b, n]
+
+ #ignore near
+ mask_x = torch.sum(torch.abs(xy_diff[:, :, 0, :]) < self.delta_diff_x, 2) > 0
+ mask_y = torch.sum(torch.abs(xy_diff[:, :, 1, :]) < self.delta_diff_y, 2) > 0
+ mask_near = mask_x & mask_y
+ mask_valid_pts = ~(mask_cos | mask_near)
+ return mask_valid_pts
+
+ def select_points_groups(self, pcd_bi, mask_kp):
+ p123 = self.select_index(mask_kp) # p1_x: [x, n]
+ pcd_bi = pcd_bi.permute((0, 2, 3, 1)).contiguous() #[1, h, w, 3(xyz)]
+ groups_pred = self.form_pw_groups(p123, pcd_bi) # [x, N, 3(x,y,z), 3(p1,p2,p3)]
+
+ # mask:[x, n]
+ mask_valid_pts = (self.filter_mask(groups_pred)).to(torch.bool) # [x, n]
+ mask_valid_batch = p123['valid_batch'].repeat(1, mask_valid_pts.shape[1]) # [x, n]
+ mask_valid = mask_valid_pts & mask_valid_batch # [x, n]
+ return groups_pred, mask_valid
+
+ def constrain_a_plane_loss(self, pw_groups_pre_i, mask_valid_i):
+ """
+ pw_groups_pre: selected points groups for the i-th plane, [N, 3(x,y,z), 3(p1,p2,p3)]
+ """
+ if torch.sum(mask_valid_i) < 2:
+ return 0.0 * torch.sum(pw_groups_pre_i), 0
+ pw_groups_pred_i = pw_groups_pre_i[mask_valid_i] # [n, 3, 3]
+ p12 = pw_groups_pred_i[:, :, 1] - pw_groups_pred_i[:, :, 0]
+ p13 = pw_groups_pred_i[:, :, 2] - pw_groups_pred_i[:, :, 0]
+ virtual_normal = torch.cross(p12, p13, dim=1) # [n, 3]
+ norm = torch.norm(virtual_normal, 2, dim=1, keepdim=True)
+ virtual_normal = virtual_normal / (norm + 1e-8)
+
+ # re-orient normals consistently
+ orient_mask = torch.sum(torch.squeeze(virtual_normal) * torch.squeeze(pw_groups_pred_i[:, :, 0]), dim=1) > 0
+ virtual_normal[orient_mask] *= -1
+ #direct = virtual_normal[:, 2] / torch.abs(virtual_normal[:, 2])
+ #virtual_normal = virtual_normal / direct[:, None] # [n, 3]
+
+ aver_normal = torch.sum(virtual_normal, dim=0)
+ aver_norm = torch.norm(aver_normal, 2, dim=0, keepdim=True)
+ aver_normal = aver_normal / (aver_norm + 1e-5) # [3]
+
+ cos_diff = 1.0 - torch.sum(virtual_normal * aver_normal, dim=1)
+ loss_sum = torch.sum(cos_diff, dim=0)
+ valid_num = cos_diff.numel()
+ return loss_sum, valid_num
+
+ def get_loss(self, pred_depth, gt_depth, ins_planes_mask, intrinsic=None):
+ """
+ Co-plane loss. Enforce points residing on the same instance plane to be co-plane.
+ :param pred_depth: predicted depth map, [B,C,H,W]
+ :param mask: mask for planes, each plane is noted with a value, [B, C, H, W]
+ :param focal_length: focal length
+ """
+ if pred_depth.ndim==3:
+ pred_depth = pred_depth[None, ...]
+ if gt_depth.ndim == 3:
+ gt_depth = gt_depth[None, ...]
+ if ins_planes_mask.ndim == 3:
+ ins_planes_mask = ins_planes_mask[None, ...]
+
+ B, _, H, W = pred_depth.shape
+ loss_sum = torch.tensor(0.0, device="cuda")
+ valid_planes_num = 0
+
+ #if 'uv' not in self._buffers or ('uv' in self._buffers and self.uv.shape[0] != B):
+ self.init_image_coor(B, H, W)
+ pcd = self.upproj_pcd(pred_depth, intrinsic.inverse())
+
+ for i in range(B):
+ mask_i = ins_planes_mask[i, :][None, :, :]
+ unique_planes = torch.unique(mask_i)
+ planes = [mask_i == m for m in unique_planes if m != 0] #[x, 1, h, w] x is the planes number
+ if len(planes) == 0:
+ continue
+ mask_planes = torch.cat(planes, dim=0) #torch.stack(planes, dim=0) #
+ pcd_grops_pred, mask_valid = self.select_points_groups(pcd[i, ...][None, :, :, :], mask_planes) # [x, N, 3(x,y,z), 3(p1,p2,p3)]
+
+ for j in range(unique_planes.numel()-1):
+ mask_valid_j = mask_valid[j, :]
+ pcd_grops_pred_j = pcd_grops_pred[j, :]
+ loss_tmp, valid_angles = self.constrain_a_plane_loss(pcd_grops_pred_j, mask_valid_j)
+ valid_planes_num += valid_angles
+ loss_sum += loss_tmp
+
+ loss = loss_sum / (valid_planes_num + 1e-6) * self.loss_weight
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = torch.sum(pred_depth) * 0
+ print(f'PWNPlane NAN error, {loss}')
+ return loss
+
+ def forward(self, prediction, target, mask, intrinsic, **kwargs): #gt_depth, pred_depth, select=True):
+ """
+        Pair-wise normal planes loss on plane instances.
+        :param prediction: predicted depth map, [B, C, H, W]
+        :param target: ground-truth depth, [B, C, H, W], padding region [padding_up, padding_down]
+ :return:
+ """
+ dataset = kwargs['dataset']
+ batch_mask = np.array(dataset) == 'Taskonomy'
+ if np.sum(batch_mask) == 0:
+ return torch.sum(prediction) * 0.0
+ ins_planes_mask = kwargs['sem_mask'] #
+ assert ins_planes_mask.ndim == 4
+ loss = self.get_loss(
+ prediction[batch_mask],
+ target[batch_mask],
+ ins_planes_mask[batch_mask],
+ intrinsic[batch_mask],
+ )
+ return loss
+
+
+if __name__ == '__main__':
+ import cv2
+ vnl_loss = PWNPlanesLoss()
+ pred_depth = torch.rand([2, 1, 385, 513]).cuda()
+ gt_depth = torch.rand([2, 1, 385, 513]).cuda()
+ gt_depth[:, :, 3:20, 40:60] = 0
+ mask_kp1 = pred_depth > 0.9
+ mask_kp2 = pred_depth < 0.5
+ mask = torch.zeros_like(gt_depth, dtype=torch.uint8)
+ mask = 1*mask_kp1 + 2* mask_kp2
+ mask[1,...] = 0
+
+
+ intrinsic = torch.tensor([[100, 0, 50], [0, 100, 50,], [0,0,1]]).cuda().float()
+ intrins = torch.stack([intrinsic, intrinsic], dim=0)
+    loss = vnl_loss(gt_depth, gt_depth, mask, intrins, dataset=np.array(['Taskonomy', 'Taskonomy']), sem_mask=mask)
+ print(loss)
diff --git a/training/mono/model/losses/Ranking.py b/training/mono/model/losses/Ranking.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cb1eecb2ea5fbfe3b73aa49afa54adf574cb02e
--- /dev/null
+++ b/training/mono/model/losses/Ranking.py
@@ -0,0 +1,342 @@
+import torch
+from torch import nn
+import numpy as np
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+import os
+
+"""
+Sampling strategies: RS (Random Sampling), EGS (Edge-Guided Sampling), and IGS (Instance-Guided Sampling)
+"""
+###########
+# RANDOM SAMPLING
+# input:
+# predictions[i,:], targets[i, :], masks[i, :], self.mask_value, self.point_pairs
+# return:
+# inputs_A, inputs_B, targets_A, targets_B, consistent_masks_A, consistent_masks_B
+###########
+def randomSampling(predictions, targets, masks, threshold, sample_num):
+
+ # find A-B point pairs from predictions
+ inputs_index = torch.masked_select(predictions, targets.gt(threshold))
+ num_effect_pixels = len(inputs_index)
+ shuffle_effect_pixels = torch.randperm(num_effect_pixels, device="cuda")
+ inputs_A = inputs_index[shuffle_effect_pixels[0:sample_num*2:2]]
+ inputs_B = inputs_index[shuffle_effect_pixels[1:sample_num*2:2]]
+ # find corresponding pairs from GT
+ target_index = torch.masked_select(targets, targets.gt(threshold))
+ targets_A = target_index[shuffle_effect_pixels[0:sample_num*2:2]]
+ targets_B = target_index[shuffle_effect_pixels[1:sample_num*2:2]]
+ # only compute the losses of point pairs with valid GT
+ consistent_masks_index = torch.masked_select(masks, targets.gt(threshold))
+ consistent_masks_A = consistent_masks_index[shuffle_effect_pixels[0:sample_num*2:2]]
+ consistent_masks_B = consistent_masks_index[shuffle_effect_pixels[1:sample_num*2:2]]
+
+    # A and B must contain the same number of sampled points
+ if len(targets_A) > len(targets_B):
+ targets_A = targets_A[:-1]
+ inputs_A = inputs_A[:-1]
+ consistent_masks_A = consistent_masks_A[:-1]
+
+ return inputs_A, inputs_B, targets_A, targets_B, consistent_masks_A, consistent_masks_B
+
+###########
+# EDGE-GUIDED SAMPLING
+# input:
+# predictions[i,:], targets[i, :], masks[i, :], edges_img[i], thetas_img[i], masks[i, :], h, w
+# return:
+# inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B
+###########
+def ind2sub(idx, cols):
+ r = torch.div(idx, cols, rounding_mode='floor') #idx // cols
+ c = idx % cols
+ return r, c
+
+def sub2ind(r, c, cols):
+ idx = (r * cols + c).int()
+ return idx
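+# Example: with cols=5, ind2sub(7, 5) -> (row=1, col=2) and sub2ind(1, 2, 5) -> 7.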
+
+def edgeGuidedSampling(predictions, targets, edges_img, thetas_img, masks, h, w):
+
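+    # Edge-guided sampling: pick random anchor pixels on image edges, then step away from
+    # each anchor along its gradient direction (theta) by random distances on both sides,
+    # giving 4 points a-b-c-d per anchor; the ordered pairs (a,b), (b,c), (c,d) are returned.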
+ # find edges
+ edges_max = edges_img.max()
+ edges_mask = edges_img.ge(edges_max*0.1)
+ edges_loc = edges_mask.nonzero()
+
+ inputs_edge = torch.masked_select(predictions, edges_mask)
+ targets_edge = torch.masked_select(targets, edges_mask)
+ thetas_edge = torch.masked_select(thetas_img, edges_mask)
+ minlen = inputs_edge.size()[0]
+
+ # find anchor points (i.e, edge points)
+ sample_num = minlen
+ index_anchors = torch.randint(0, minlen, (sample_num,), dtype=torch.long, device="cuda")
+ anchors = torch.gather(inputs_edge, 0, index_anchors)
+ theta_anchors = torch.gather(thetas_edge, 0, index_anchors)
+ row_anchors, col_anchors = ind2sub(edges_loc[index_anchors].squeeze(1), w)
+    ## compute the coordinates of the 4 points; distances are sampled from [2, 40)
+ distance_matrix = torch.randint(2, 40, (4,sample_num), device="cuda")
+ pos_or_neg = torch.ones(4, sample_num, device="cuda")
+ pos_or_neg[:2,:] = -pos_or_neg[:2,:]
+ distance_matrix = distance_matrix.float() * pos_or_neg
+ col = col_anchors.unsqueeze(0).expand(4, sample_num).long() + torch.round(distance_matrix.float() * torch.abs(torch.cos(theta_anchors)).unsqueeze(0)).long()
+ row = row_anchors.unsqueeze(0).expand(4, sample_num).long() + torch.round(distance_matrix.float() * torch.abs(torch.sin(theta_anchors)).unsqueeze(0)).long()
+
+    # constrain 0 <= col <= w-1, 0 <= row <= h-1
+    col[col<0] = 0
+    col[col>w-1] = w-1
+ row[row<0] = 0
+ row[row>h-1] = h-1
+
+ # a-b, b-c, c-d
+ a = sub2ind(row[0,:], col[0,:], w)
+ b = sub2ind(row[1,:], col[1,:], w)
+ c = sub2ind(row[2,:], col[2,:], w)
+ d = sub2ind(row[3,:], col[3,:], w)
+ A = torch.cat((a,b,c), 0)
+ B = torch.cat((b,c,d), 0)
+
+ inputs_A = torch.gather(predictions, 0, A.long())
+ inputs_B = torch.gather(predictions, 0, B.long())
+ targets_A = torch.gather(targets, 0, A.long())
+ targets_B = torch.gather(targets, 0, B.long())
+ masks_A = torch.gather(masks, 0, A.long())
+ masks_B = torch.gather(masks, 0, B.long())
+
+ # create A, B, C, D mask for visualization
+ # vis_mask = masks.reshape(h, w).cpu().numpy()
+ # vis_row = row.cpu()
+ # vis_col = col.cpu()
+ # visual_A = np.zeros((h, w)).astype(np.bool)
+ # visual_B = np.zeros_like(visual_A)
+ # visual_C = np.zeros_like(visual_A)
+ # visual_D = np.zeros_like(visual_A)
+ # visual_A[vis_row[0, :], vis_col[0, :]] = True
+ # visual_B[vis_row[1, :], vis_col[1, :]] = True
+ # visual_C[vis_row[2, :], vis_col[2, :]] = True
+ # visual_D[vis_row[3, :], vis_col[3, :]] = True
+ # visual_ABCD = [visual_A & vis_mask, visual_B & vis_mask,
+ # visual_C& vis_mask, visual_D& vis_mask]
+ return inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num
+
+
+######################################################
+# Ranking loss (Random sampling)
+#####################################################
+class RankingLoss(nn.Module):
+ def __init__(self, point_pairs=5000, sigma=0.03, alpha=1.0, mask_value=-1e-8, loss_weight=1, **kwargs):
+ super(RankingLoss, self).__init__()
+ self.point_pairs = point_pairs # number of point pairs
+ self.sigma = sigma # used for determining the ordinal relationship between a selected pair
+ self.alpha = alpha # used for balancing the effect of = and (<,>)
+ self.mask_value = mask_value
+ self.loss_weight = loss_weight
+ self.eps = 1e-6
+
+ def forward(self, prediction, target, mask=None, **kwargs):
+ n,c,h,w = target.size()
+        if mask is None:
+ mask = target > self.mask_value
+ if n != 1:
+ prediction = prediction.view(n, -1)#.double()
+ target = target.view(n, -1)#.double()
+ mask = mask.view(n, -1)#.double()
+ else:
+ prediction = prediction.contiguous().view(1, -1)#.double()
+ target = target.contiguous().view(1, -1)#.double()
+ mask = mask.contiguous().view(1, -1)#.double()
+
+ loss = 0.0 #torch.tensor([0.0]).cuda()
+ valid_samples = 0
+ for i in range(n):
+ # find A-B point pairs
+ inputs_A, inputs_B, targets_A, targets_B, consistent_masks_A, consistent_masks_B = randomSampling(prediction[i,:], target[i, :], mask[i, :], self.mask_value, self.point_pairs)
+
+ #GT ordinal relationship
+ target_ratio = torch.div(targets_A, targets_B+self.eps)
+ mask_eq = target_ratio.lt(1.0 + self.sigma) * target_ratio.gt(1.0/(1.0+self.sigma))
+ labels = torch.zeros_like(target_ratio)
+ labels[target_ratio.ge(1.0 + self.sigma)] = 1
+ labels[target_ratio.le(1.0/(1.0+self.sigma))] = -1
+
+ # consider forward-backward consistency checking, only compute the losses of point pairs with valid GT
+ consistency_mask = consistent_masks_A & consistent_masks_B
+
+ # compute loss
+ equal_loss = (inputs_A - inputs_B).pow(2)[mask_eq & consistency_mask]
+ unequal_loss = torch.log(1 + torch.exp((-inputs_A + inputs_B) * labels))[(~mask_eq) & consistency_mask]
+
+ loss = loss + self.alpha * equal_loss.sum() + unequal_loss.sum()
+ valid_samples = valid_samples + unequal_loss.numel() + equal_loss.numel()
+ loss = loss / (valid_samples + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+            raise RuntimeError(f'Ranking loss error, {loss}')
+ return loss * self.loss_weight
+
+
+
+
+
+######################################################
+# EdgeguidedRankingLoss (with regularization term)
+# Please comment regularization_loss if you don't want to use multi-scale gradient matching term
+#####################################################
+class EdgeguidedRankingLoss(nn.Module):
+ def __init__(self, point_pairs=5000, sigma=0.03, alpha=1.0, mask_value=1e-6, loss_weight=1.0, data_type=['rel', 'sfm', 'stereo', 'lidar'], **kwargs):
+ super(EdgeguidedRankingLoss, self).__init__()
+ self.point_pairs = point_pairs # number of point pairs
+ self.sigma = sigma # used for determining the ordinal relationship between a selected pair
+ self.alpha = alpha # used for balancing the effect of = and (<,>)
+ self.mask_value = mask_value
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+ def getEdge(self, images):
+ n,c,h,w = images.size()
+ a = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device="cuda").view((1,1,3,3)).repeat(1, 1, 1, 1)
+ b = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32, device="cuda").view((1,1,3,3)).repeat(1, 1, 1, 1)
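+        # a and b are 3x3 Sobel kernels (horizontal / vertical gradients); the valid
+        # convolution shrinks the map by one pixel per side, which F.pad below restores.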
+ if c == 3:
+ gradient_x = F.conv2d(images[:,0,:,:].unsqueeze(1), a)
+ gradient_y = F.conv2d(images[:,0,:,:].unsqueeze(1), b)
+ else:
+ gradient_x = F.conv2d(images, a)
+ gradient_y = F.conv2d(images, b)
+ edges = torch.sqrt(torch.pow(gradient_x,2)+ torch.pow(gradient_y,2))
+ edges = F.pad(edges, (1,1,1,1), "constant", 0)
+ thetas = torch.atan2(gradient_y, gradient_x)
+ thetas = F.pad(thetas, (1,1,1,1), "constant", 0)
+
+ return edges, thetas
+
+ def visual_check(self, rgb, samples):
+ rgb = rgb.cpu().squeeze().numpy()
+
+ mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis]
+        std = np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis]
+
+ rgb = ((rgb * std) + mean).astype(np.uint8).transpose((1, 2, 0))
+ mask_A, mask_B, mask_C, mask_D = samples
+        rgb[mask_A.astype(bool)] = [255, 0, 0]
+        rgb[mask_B.astype(bool)] = [0, 255, 0]
+        rgb[mask_C.astype(bool)] = [0, 0, 255]
+        rgb[mask_D.astype(bool)] = [255, 255, 0]
+
+ filename = str(np.random.randint(10000))
+ save_path = os.path.join('test_ranking', filename + '.png')
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ plt.imsave(save_path, rgb)
+
+ def forward(self, prediction, target, mask=None, input=None, **kwargs):
+ loss = self.get_loss(prediction, target, mask, input, **kwargs)
+ return loss
+
+ def get_loss(self, prediction, target, mask=None, input=None, **kwargs):
+        if mask is None:
+ mask = target > self.mask_value
+ # find edges from RGB
+ edges_img, thetas_img = self.getEdge(input)
+ # find edges from target depths
+ edges_depth, thetas_depth = self.getEdge(target)
+
+ #=============================
+ n,c,h,w = target.size()
+ if n != 1:
+ prediction = prediction.view(n, -1)#.double()
+ target = target.view(n, -1)#.double()
+ mask = mask.view(n, -1)#.double()
+ edges_img = edges_img.view(n, -1)#.double()
+ thetas_img = thetas_img.view(n, -1)#.double()
+ edges_depth = edges_depth.view(n, -1)#.double()
+ thetas_depth = thetas_depth.view(n, -1)#.double()
+ else:
+ prediction = prediction.contiguous().view(1, -1)#.double()
+ target = target.contiguous().view(1, -1)#.double()
+ mask = mask.contiguous().view(1, -1)#.double()
+ edges_img = edges_img.contiguous().view(1, -1)#.double()
+ thetas_img = thetas_img.contiguous().view(1, -1)#.double()
+ edges_depth = edges_depth.view(1, -1)#.double()
+ thetas_depth = thetas_depth.view(1, -1)#.double()
+
+ # initialization
+ loss = 0.0 #torch.tensor([0.0]).cuda()
+ valid_samples = 0
+
+ for i in range(n):
+ # Edge-Guided sampling from RGB predictions, targets, edges_img, thetas_img, masks, h, w
+ inputs_A, inputs_B, targets_A, targets_B, masks_A, masks_B, sample_num = edgeGuidedSampling(
+ prediction[i,:],
+ target[i, :],
+ edges_img[i],
+ thetas_img[i],
+ mask[i, :],
+ h,
+ w
+ )
+ # # Edge-Guided sampling from depth
+ # inputs_A_depth, inputs_B_depth, targets_A_depth, targets_B_depth, masks_A_depth, masks_B_depth, sample_num_depth = edgeGuidedSampling(
+ # prediction[i,:],
+ # target[i, :],
+ # edges_depth[i],
+ # thetas_depth[i],
+ # mask[i, :],
+ # h,
+ # w
+ # )
+
+ # Random Sampling predictions, targets, masks, threshold, sample_num
+ random_sample_num = sample_num
+ random_inputs_A, random_inputs_B, random_targets_A, random_targets_B, random_masks_A, random_masks_B = randomSampling(
+ prediction[i,:],
+ target[i, :],
+ mask[i, :],
+ self.mask_value,
+ random_sample_num
+ )
+
+ # Combine EGS + RS + EGS_depth
+ inputs_A_merge = torch.cat((inputs_A, random_inputs_A,), 0)
+ inputs_B_merge = torch.cat((inputs_B, random_inputs_B,), 0)
+ targets_A_merge = torch.cat((targets_A, random_targets_A,), 0)
+ targets_B_merge = torch.cat((targets_B, random_targets_B,), 0)
+ masks_A_merge = torch.cat((masks_A, random_masks_A,), 0)
+ masks_B_merge = torch.cat((masks_B, random_masks_B,), 0)
+
+ #GT ordinal relationship
+ target_ratio = torch.div(targets_A_merge + 1e-6, targets_B_merge + 1e-6)
+ mask_eq = target_ratio.lt(1.0 + self.sigma) & target_ratio.gt(1.0/(1.0+self.sigma))
+ labels = torch.zeros_like(target_ratio)
+ labels[target_ratio.ge(1.0 + self.sigma)] = 1
+ labels[target_ratio.le(1.0/(1.0+self.sigma))] = -1
+
+            # consider forward-backward consistency checking, i.e., only compute losses of point pairs with valid GT
+ consistency_mask = masks_A_merge & masks_B_merge
+
+ equal_loss = (inputs_A_merge - inputs_B_merge).pow(2)[mask_eq & consistency_mask]
+ unequal_loss = torch.log(1 + torch.exp((-inputs_A_merge + inputs_B_merge) * labels))[(~mask_eq) & consistency_mask]
+
+ loss = loss + self.alpha * torch.sum(equal_loss) + torch.sum(unequal_loss)
+ valid_samples = valid_samples + equal_loss.numel()
+ valid_samples = valid_samples + unequal_loss.numel()
+ loss = loss / (valid_samples + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+            raise RuntimeError(f'Edge-guided ranking loss error, {loss}')
+ return loss * self.loss_weight
+
+
+if __name__ == '__main__':
+ import cv2
+
+ rank_loss = EdgeguidedRankingLoss()
+ pred_depth = np.random.randn(2, 1, 480, 640)
+ gt_depth = np.ones((2, 1, 480, 640)) #np.random.randn(2, 1, 480, 640)
+ # gt_depth = cv2.imread('/hardware/yifanliu/SUNRGBD/sunrgbd-meta-data/sunrgbd_test_depth/2.png', -1)
+ # gt_depth = gt_depth[None, :, :, None]
+ # pred_depth = gt_depth[:, :, ::-1, :]
+ gt_depth = torch.tensor(np.asarray(gt_depth, np.float32)).cuda()
+ pred_depth = torch.tensor(np.asarray(pred_depth, np.float32)).cuda()
+ input = np.random.randn(2, 3, 480, 640)
+ input_torch = torch.tensor(np.asarray(input, np.float32)).cuda()
+ loss = rank_loss(gt_depth, gt_depth, gt_depth>0, input=input_torch)
+ print(loss)
diff --git a/training/mono/model/losses/Regularization.py b/training/mono/model/losses/Regularization.py
new file mode 100644
index 0000000000000000000000000000000000000000..f493d1117e707b786e63a92482aabbaf8c79643b
--- /dev/null
+++ b/training/mono/model/losses/Regularization.py
@@ -0,0 +1,18 @@
+import torch
+import torch.nn as nn
+
+class RegularizationLoss(nn.Module):
+ """
+ Enforce losses on pixels without any gts.
+ """
+ def __init__(self, loss_weight=0.1, data_type=['sfm', 'stereo', 'lidar'], **kwargs):
+ super(RegularizationLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+ def forward(self, prediction, target, mask=None, **kwargs):
+ pred_wo_gt = prediction[~mask]
+ #loss = - torch.sum(pred_wo_gt) / (pred_wo_gt.numel() + 1e-8)
+ loss = 1/ (torch.sum(pred_wo_gt) / (pred_wo_gt.numel() + self.eps))
+ return loss * self.loss_weight
\ No newline at end of file
diff --git a/training/mono/model/losses/SSIL.py b/training/mono/model/losses/SSIL.py
new file mode 100644
index 0000000000000000000000000000000000000000..38135d3a90481f209eaac22aa8cbe1cac8a80990
--- /dev/null
+++ b/training/mono/model/losses/SSIL.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+
+class SSILoss(nn.Module):
+ """
+ Scale shift invariant MAE loss.
+    loss = MAE((d - median(d)) / s - (d' - median(d')) / s'), where s = mean(|d - median(d)|)
+ """
+ def __init__(self, loss_weight=1, data_type=['sfm', 'stereo', 'lidar'], **kwargs):
+ super(SSILoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+ def ssi_mae(self, target, prediction, mask):
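+        # Shift each depth map by its median and scale by the mean absolute deviation from
+        # that median, then take the MAE between the aligned maps (scale/shift invariant).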
+ valid_pixes = torch.sum(mask) + self.eps
+
+ gt_median = torch.median(target) if target.numel() else 0
+ gt_s = torch.abs(target - gt_median).sum() / valid_pixes
+ gt_trans = (target - gt_median) / (gt_s + self.eps)
+
+ pred_median = torch.median(prediction) if prediction.numel() else 0
+ pred_s = torch.abs(prediction - pred_median).sum() / valid_pixes
+ pred_trans = (prediction - pred_median) / (pred_s + self.eps)
+
+ ssi_mae_sum = torch.sum(torch.abs(gt_trans - pred_trans))
+ return ssi_mae_sum, valid_pixes
+
+ def forward(self, prediction, target, mask=None, **kwargs):
+ """
+ Calculate loss.
+ """
+ B, C, H, W = prediction.shape
+ loss = 0
+ valid_pix = 0
+ for i in range(B):
+ mask_i = mask[i, ...]
+ gt_depth_i = target[i, ...][mask_i]
+ pred_depth_i = prediction[i, ...][mask_i]
+            ssi_sum, valid_pix_i = self.ssi_mae(gt_depth_i, pred_depth_i, mask_i)
+ loss += ssi_sum
+ valid_pix += valid_pix_i
+ loss /= (valid_pix + self.eps)
+ return loss * self.loss_weight
+
+if __name__ == '__main__':
+ torch.manual_seed(1)
+ torch.cuda.manual_seed_all(1)
+
+ ssil = SSILoss()
+ pred = torch.rand((2, 1, 256, 256)).cuda()
+ gt = torch.rand((2, 1, 256, 256)).cuda()#torch.zeros_like(pred).cuda() #
+ gt[:, :, 100:256, 0:100] = -1
+ mask = gt > 0
+ out = ssil(pred, gt, mask)
+ print(out)
diff --git a/training/mono/model/losses/ScaleAlignLoss.py b/training/mono/model/losses/ScaleAlignLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ded0509e2d8f971a0093eca3819dbdb6b96f43e7
--- /dev/null
+++ b/training/mono/model/losses/ScaleAlignLoss.py
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+
+class ScaleAlignLoss(nn.Module):
+ """
+    Align the predicted global scale with the median of the valid ground-truth depth for each sample.
+ """
+ def __init__(self, data_type=['lidar', 'denselidar', 'stereo', 'denselidar_syn'], loss_weight=1.0, disable_dataset=['MapillaryPSD'], **kwargs):
+ super(ScaleAlignLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.disable_dataset = disable_dataset
+
+ def forward(self, prediction, target, mask, scale, **kwargs):
+ device = target.device
+
+ B, C, H, W = prediction.shape
+
+
+ # median_pred, _ = torch.median(prediction.view(B, C*H*W), 1)
+ # median_pred = median_pred.detach()
+
+ # scale_factor = torch.zeros_like(scale).squeeze(3).squeeze(2).squeeze(1)
+ # for i in range(B):
+ # mask_i = mask[i, ...]
+ # if torch.sum(mask_i) > 10:
+ # scale_factor[i] = torch.median(target[i, ...][mask_i]) / (torch.median(prediction[i, ...][mask_i]) + 1e-8)
+ # else:
+ # scale_factor[i] = 0
+
+ # target_scale = (median_pred * scale_factor)
+
+ # batches_dataset = kwargs['dataset']
+ # self.batch_valid = torch.tensor([1 if batch_dataset not in self.disable_dataset else 0 \
+ # for batch_dataset in batches_dataset], device=device)
+
+ # batch_valid = self.batch_valid * (scale_factor > 1e-8)
+
+ # scale_diff = torch.abs(scale.squeeze(3).squeeze(2).squeeze(1) - scale_factor * median_pred)
+
+ batches_dataset = kwargs['dataset']
+ self.batch_valid = torch.tensor([1 if batch_dataset not in self.disable_dataset else 0 \
+ for batch_dataset in batches_dataset], device=device)
+
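+        # Per-sample scale target = median of the valid GT depths; the loss is the L1
+        # distance between the predicted scale and that median, averaged over samples
+        # from non-disabled datasets that have enough valid pixels (positive median).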
+ scale_tgt = torch.zeros_like(scale).squeeze(3).squeeze(2).squeeze(1)
+ for i in range(B):
+ mask_i = mask[i, ...]
+ if torch.sum(mask_i) > 10:
+ scale_tgt[i] = torch.median(target[i, ...][mask_i])
+ else:
+ scale_tgt[i] = 0
+
+ batch_valid = self.batch_valid * (scale_tgt > 1e-8)
+ scale_diff = torch.abs(scale.squeeze(3).squeeze(2).squeeze(1) - scale_tgt)
+ loss = torch.sum(scale_diff * batch_valid) / (torch.sum(batch_valid) + 1e-8)
+
+ return loss * self.loss_weight
\ No newline at end of file
diff --git a/training/mono/model/losses/ScaleInvL1.py b/training/mono/model/losses/ScaleInvL1.py
new file mode 100644
index 0000000000000000000000000000000000000000..fad42d54015a6102c41ae5292141bb0f9d5d3f5e
--- /dev/null
+++ b/training/mono/model/losses/ScaleInvL1.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+class ScaleInvL1Loss(nn.Module):
+ """
+ Compute scale-invariant L1 loss.
+ """
+ def __init__(self, loss_weight=1, data_type=['sfm', 'denselidar_nometric', 'denselidar_syn'], **kwargs):
+ super(ScaleInvL1Loss, self).__init__()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+ def forward(self, prediction, target, mask=None, **kwargs):
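+        # Align the prediction to the target by the ratio of their masked medians, then
+        # compute the mean absolute error over valid pixels.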
+ B, _, _, _ = target.shape
+ target_nan = target.clone()
+ target_nan[~mask] = torch.nan
+ median_target = torch.nanmedian(target_nan.view(B, -1), dim=1)[0]
+ prediction_nan = prediction.clone().detach()
+ prediction_nan[~mask] = torch.nan
+ median_prediction = torch.nanmedian(prediction_nan.view(B, -1), dim=1)[0]
+ scale = median_target / median_prediction
+ scale[torch.isnan(scale)] = 0
+ pred_scale = prediction * scale[:, None, None, None]
+
+ target_valid = target * mask
+ pred_valid = pred_scale * mask
+ diff = torch.abs(pred_valid - target_valid)
+ # disp_diff = diff / (target_valid + self.eps)
+ loss = torch.sum(diff) / (torch.sum(mask) + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = 0 * torch.sum(prediction)
+ print(f'Scale-invariant L1 NAN error, {loss}')
+ #raise RuntimeError(f'Silog error, {loss}, d_square_mean: {d_square_mean}, d_mean: {d_mean}')
+ return loss * self.loss_weight
diff --git a/training/mono/model/losses/SiLog.py b/training/mono/model/losses/SiLog.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce336f9f9a5aea56ad78b29861afb4ae3b0302e6
--- /dev/null
+++ b/training/mono/model/losses/SiLog.py
@@ -0,0 +1,38 @@
+import torch
+import torch.nn as nn
+
+class SilogLoss(nn.Module):
+ """
+ Compute SILog loss. See https://papers.nips.cc/paper/2014/file/7bccfde7714a1ebadf06c5f4cea752c1-Paper.pdf for
+ more information about scale-invariant loss.
+ """
+ def __init__(self, variance_focus=0.5, loss_weight=1, data_type=['stereo', 'lidar'], **kwargs):
+ super(SilogLoss, self).__init__()
+ self.variance_focus = variance_focus
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+ def silog_loss(self, prediction, target, mask):
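+        # d_i = log(pred_i) - log(gt_i);  SILog = mean(d^2) - variance_focus * mean(d)^2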
+ d = torch.log(prediction[mask]) - torch.log(target[mask])
+ d_square_mean = torch.sum(d ** 2) / (d.numel() + self.eps)
+ d_mean = torch.sum(d) / (d.numel() + self.eps)
+ loss = d_square_mean - self.variance_focus * (d_mean ** 2)
+ return loss
+
+ def forward(self, prediction, target, mask=None, **kwargs):
+ if target[mask].numel() > 0:
+ loss = self.silog_loss(prediction, target, mask)
+ else:
+ loss = 0 * torch.sum(prediction)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+            raise RuntimeError(f'Silog error, {loss}')
+ return loss * self.loss_weight
+
+if __name__ == '__main__':
+ silog = SilogLoss()
+ pred = torch.rand((2, 3, 256, 256)).cuda()
+ gt = torch.zeros_like(pred) #torch.rand((2, 3, 256, 256)).cuda()
+ mask = gt > 0
+ out = silog(pred, gt, mask)
+ print(out)
diff --git a/training/mono/model/losses/SkyRegularization.py b/training/mono/model/losses/SkyRegularization.py
new file mode 100644
index 0000000000000000000000000000000000000000..a548fcc1aefae0cd201cb99956acfee3bb2bc1a7
--- /dev/null
+++ b/training/mono/model/losses/SkyRegularization.py
@@ -0,0 +1,79 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class SkyRegularizationLoss(nn.Module):
+ """
+    Regularize predictions on sky pixels (selected via the semantic mask) towards a fixed depth value and, optionally, a fixed sky normal direction.
+ """
+ def __init__(self, loss_weight=0.1, data_type=['sfm', 'stereo', 'lidar', 'denselidar', 'denselidar_nometric', 'denselidar_syn'], sky_id=142, sample_ratio=0.4, regress_value=1.8, normal_regress=None, normal_weight=1.0, **kwargs):
+ super(SkyRegularizationLoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.sky_id = sky_id
+ self.sample_ratio = sample_ratio
+ self.eps = 1e-6
+ self.regress_value = regress_value
+ self.normal_regress = normal_regress
+ self.normal_weight = normal_weight
+
+ def loss1(self, pred_sky):
+ loss = 1/ torch.exp((torch.sum(pred_sky) / (pred_sky.numel() + self.eps)))
+ return loss
+
+ def loss2(self, pred_sky):
+ loss = torch.sum(torch.abs(pred_sky - self.regress_value)) / (pred_sky.numel() + self.eps)
+ return loss
+
+ def loss_norm(self, pred_norm, sky_mask):
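+        # Push predicted normals at sky pixels towards the configured sky direction:
+        # the loss averages (1 - cos(pred_normal, sky_normal)) over the sky mask, skipping
+        # pixels where the two are already (anti-)parallel beyond |cos| = 0.999.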
+ sky_norm = torch.FloatTensor(self.normal_regress).cuda()
+ sky_norm = sky_norm.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+ dot = torch.cosine_similarity(pred_norm[:, :3, :, :].clone(), sky_norm, dim=1)
+
+ sky_mask_float = sky_mask.float().squeeze()
+ valid_mask = sky_mask_float \
+ * (dot.detach() < 0.999).float() \
+ * (dot.detach() > -0.999).float()
+
+ al = (1 - dot) * valid_mask
+ loss = torch.sum(al) / (torch.sum(sky_mask_float) + self.eps)
+ return loss
+
+ def forward(self, prediction, target, prediction_normal=None, mask=None, sem_mask=None, **kwargs):
+ sky_mask = sem_mask == self.sky_id
+ pred_sky = prediction[sky_mask]
+ pred_sky_numel = pred_sky.numel()
+
+ if pred_sky.numel() > 50:
+ samples = np.random.choice(pred_sky_numel, int(pred_sky_numel*self.sample_ratio), replace=False)
+
+ if pred_sky.numel() > 0:
+ #loss = - torch.sum(pred_wo_gt) / (pred_wo_gt.numel() + 1e-8)
+ loss = self.loss2(pred_sky)
+
+            if (prediction_normal is not None) and (self.normal_regress is not None):
+ loss_normal = self.loss_norm(prediction_normal, sky_mask)
+ loss = loss + loss_normal * self.normal_weight
+
+ else:
+ loss = torch.sum(prediction) * 0
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = torch.sum(prediction) * 0
+ print(f'SkyRegularization NAN error, {loss}')
+ # raise RuntimeError(f'Sky Loss error, {loss}')
+
+ return loss * self.loss_weight
+
+if __name__ == '__main__':
+ import cv2
+ sky = SkyRegularizationLoss()
+ pred_depth = np.random.random([2, 1, 480, 640])
+ gt_depth = np.zeros_like(pred_depth) #np.random.random([2, 1, 480, 640])
+ intrinsic = [[[100, 0, 200], [0, 100, 200], [0, 0, 1]], [[100, 0, 200], [0, 100, 200], [0, 0, 1]],]
+ gt_depth = torch.tensor(np.array(gt_depth, np.float32)).cuda()
+ pred_depth = torch.tensor(np.array(pred_depth, np.float32)).cuda()
+ intrinsic = torch.tensor(np.array(intrinsic, np.float32)).cuda()
+ mask = gt_depth > 0
+ loss1 = sky(pred_depth, gt_depth, mask, mask, intrinsic)
+ print(loss1)
\ No newline at end of file
diff --git a/training/mono/model/losses/VNL.py b/training/mono/model/losses/VNL.py
new file mode 100644
index 0000000000000000000000000000000000000000..111b1ae690709417b2d9d15e6d930bfa69d4465e
--- /dev/null
+++ b/training/mono/model/losses/VNL.py
@@ -0,0 +1,260 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class VNLoss(nn.Module):
+ """
+ Virtual Normal Loss.
+ """
+ def __init__(self,
+ delta_cos=0.867, delta_diff_x=0.01,
+ delta_diff_y=0.01, delta_diff_z=0.01,
+ delta_z=1e-5, sample_ratio=0.15,
+ loss_weight=1.0, data_type=['sfm', 'stereo', 'lidar', 'denselidar', 'denselidar_nometric', 'denselidar_syn'], **kwargs):
+ super(VNLoss, self).__init__()
+ self.delta_cos = delta_cos
+ self.delta_diff_x = delta_diff_x
+ self.delta_diff_y = delta_diff_y
+ self.delta_diff_z = delta_diff_z
+ self.delta_z = delta_z
+ self.sample_ratio = sample_ratio
+ self.loss_weight = loss_weight
+ self.data_type = data_type
+ self.eps = 1e-6
+
+
+ def init_image_coor(self, intrinsic, height, width):
+ # x_row = torch.arange(0, W, device="cuda")
+ # x = torch.tile(x_row, (H, 1))
+ # x = x.to(torch.float32)
+ # u_m_u0 = x[None, None, :, :] - u0
+ # self.register_buffer('u_m_u0', u_m_u0, persistent=False)
+
+ # y_col = torch.arange(0, H, device="cuda") # y_col = np.arange(0, height)
+ # y = torch.transpose(torch.tile(y_col, (W, 1)), 1, 0)
+ # y = y.to(torch.float32)
+ # v_m_v0 = y[None, None, :, :] - v0
+ # self.register_buffer('v_m_v0', v_m_v0, persistent=False)
+
+ # pix_idx_mat = torch.arange(H*W, device="cuda").reshape((H, W))
+ # self.register_buffer('pix_idx_mat', pix_idx_mat, persistent=False)
+ #self.pix_idx_mat = torch.arange(height*width, device="cuda").reshape((height, width))
+
+ u0 = intrinsic[:, 0, 2][:, None, None, None]
+ v0 = intrinsic[:, 1, 2][:, None, None, None]
+ y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device="cuda"),
+ torch.arange(0, width, dtype=torch.float32, device="cuda")], indexing='ij')
+ u_m_u0 = x[None, None, :, :] - u0
+ v_m_v0 = y[None, None, :, :] - v0
+ # return u_m_u0, v_m_v0
+ self.register_buffer('v_m_v0', v_m_v0, persistent=False)
+ self.register_buffer('u_m_u0', u_m_u0, persistent=False)
+
+ def transfer_xyz(self, depth, focal_length, u_m_u0, v_m_v0):
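+        # Pinhole back-projection: X = (u - u0) * Z / f, Y = (v - v0) * Z / f, Z = depth.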
+ x = u_m_u0 * depth / focal_length
+ y = v_m_v0 * depth / focal_length
+ z = depth
+ pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1).contiguous() # [b, h, w, c]
+ return pw
+
+ def select_index(self, B, H, W, mask):
+ """
+
+ """
+ p1 = []
+ p2 = []
+ p3 = []
+ pix_idx_mat = torch.arange(H*W, device="cuda").reshape((H, W))
+ for i in range(B):
+ inputs_index = torch.masked_select(pix_idx_mat, mask[i, ...].gt(self.eps))
+ num_effect_pixels = len(inputs_index)
+
+ intend_sample_num = int(H * W * self.sample_ratio)
+ sample_num = intend_sample_num if num_effect_pixels >= intend_sample_num else num_effect_pixels
+
+ shuffle_effect_pixels = torch.randperm(num_effect_pixels, device="cuda")
+ p1i = inputs_index[shuffle_effect_pixels[:sample_num]]
+ shuffle_effect_pixels = torch.randperm(num_effect_pixels, device="cuda")
+ p2i = inputs_index[shuffle_effect_pixels[:sample_num]]
+ shuffle_effect_pixels = torch.randperm(num_effect_pixels, device="cuda")
+ p3i = inputs_index[shuffle_effect_pixels[:sample_num]]
+
+ cat_null = torch.tensor(([0,] * (intend_sample_num - sample_num)), dtype=torch.long, device="cuda")
+ p1i = torch.cat([p1i, cat_null])
+ p2i = torch.cat([p2i, cat_null])
+ p3i = torch.cat([p3i, cat_null])
+
+ p1.append(p1i)
+ p2.append(p2i)
+ p3.append(p3i)
+
+ p1 = torch.stack(p1, dim=0)
+ p2 = torch.stack(p2, dim=0)
+ p3 = torch.stack(p3, dim=0)
+
+ p1_x = p1 % W
+ p1_y = torch.div(p1, W, rounding_mode='trunc').long() # p1 // W
+
+ p2_x = p2 % W
+ p2_y = torch.div(p2, W, rounding_mode='trunc').long() # p2 // W
+
+ p3_x = p3 % W
+ p3_y = torch.div(p3, W, rounding_mode='trunc').long() # p3 // W
+ p123 = {'p1_x': p1_x, 'p1_y': p1_y, 'p2_x': p2_x, 'p2_y': p2_y, 'p3_x': p3_x, 'p3_y': p3_y}
+ return p123
+
+ def form_pw_groups(self, p123, pw):
+ """
+        Form 3D point groups, with 3 points in each group.
+ :param p123: points index
+ :param pw: 3D points
+ :return:
+ """
+ B, _, _, _ = pw.shape
+ p1_x = p123['p1_x']
+ p1_y = p123['p1_y']
+ p2_x = p123['p2_x']
+ p2_y = p123['p2_y']
+ p3_x = p123['p3_x']
+ p3_y = p123['p3_y']
+
+ pw_groups = []
+ for i in range(B):
+ pw1 = pw[i, p1_y[i], p1_x[i], :]
+ pw2 = pw[i, p2_y[i], p2_x[i], :]
+ pw3 = pw[i, p3_y[i], p3_x[i], :]
+ pw_bi = torch.stack([pw1, pw2, pw3], dim=2)
+ pw_groups.append(pw_bi)
+ # [B, N, 3(x,y,z), 3(p1,p2,p3)]
+ pw_groups = torch.stack(pw_groups, dim=0)
+ return pw_groups
+
+ def filter_mask(self, p123, gt_xyz, delta_cos=0.867,
+ delta_diff_x=0.005,
+ delta_diff_y=0.005,
+ delta_diff_z=0.005):
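+        # Keep only triplets that (i) are not nearly collinear (no pair of difference
+        # vectors with |cos| above delta_cos), (ii) are not too close along any axis, and
+        # (iii) have all three depths above delta_z.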
+ pw = self.form_pw_groups(p123, gt_xyz)
+ pw12 = pw[:, :, :, 1] - pw[:, :, :, 0]
+ pw13 = pw[:, :, :, 2] - pw[:, :, :, 0]
+ pw23 = pw[:, :, :, 2] - pw[:, :, :, 1]
+ ###ignore linear
+ pw_diff = torch.cat([pw12[:, :, :, np.newaxis], pw13[:, :, :, np.newaxis], pw23[:, :, :, np.newaxis]],
+ 3) # [b, n, 3, 3]
+ m_batchsize, groups, coords, index = pw_diff.shape
+ proj_query = pw_diff.view(m_batchsize * groups, -1, index).permute(0, 2, 1).contiguous() # (B* X CX(3)) [bn, 3(p123), 3(xyz)]
+ proj_key = pw_diff.contiguous().view(m_batchsize * groups, -1, index) # B X (3)*C [bn, 3(xyz), 3(p123)]
+ q_norm = proj_query.norm(2, dim=2)
+ nm = torch.bmm(q_norm.contiguous().view(m_batchsize * groups, index, 1), q_norm.view(m_batchsize * groups, 1, index)) #[]
+ energy = torch.bmm(proj_query, proj_key) # transpose check [bn, 3(p123), 3(p123)]
+ norm_energy = energy / (nm + self.eps)
+ norm_energy = norm_energy.contiguous().view(m_batchsize * groups, -1)
+        mask_cos = torch.sum((norm_energy > delta_cos) + (norm_energy < -delta_cos), 1) > 3 # ignore nearly collinear triplets
+        mask_cos = mask_cos.contiguous().view(m_batchsize, groups)
+        ## ignore padding and invalid depth
+ mask_pad = torch.sum(pw[:, :, 2, :] > self.delta_z, 2) == 3
+
+ ###ignore near
+ mask_x = torch.sum(torch.abs(pw_diff[:, :, 0, :]) < delta_diff_x, 2) > 0
+ mask_y = torch.sum(torch.abs(pw_diff[:, :, 1, :]) < delta_diff_y, 2) > 0
+ mask_z = torch.sum(torch.abs(pw_diff[:, :, 2, :]) < delta_diff_z, 2) > 0
+
+ mask_ignore = (mask_x & mask_y & mask_z) | mask_cos
+ mask_near = ~mask_ignore
+ mask = mask_pad & mask_near
+
+ return mask, pw
+
+ def select_points_groups(self, gt_depth, pred_depth, intrinsic, mask):
+ B, C, H, W = gt_depth.shape
+ focal_length = intrinsic[:, 0, 0][:, None, None, None]
+ u_m_u0, v_m_v0 = self.u_m_u0, self.v_m_v0 # self.init_image_coor(intrinsic, height=H, width=W)
+
+ pw_gt = self.transfer_xyz(gt_depth, focal_length, u_m_u0, v_m_v0)
+ pw_pred = self.transfer_xyz(pred_depth, focal_length, u_m_u0, v_m_v0)
+
+ p123 = self.select_index(B, H, W, mask)
+ # mask:[b, n], pw_groups_gt: [b, n, 3(x,y,z), 3(p1,p2,p3)]
+ mask, pw_groups_gt = self.filter_mask(p123, pw_gt,
+ delta_cos=0.867,
+ delta_diff_x=0.005,
+ delta_diff_y=0.005,
+ delta_diff_z=0.005)
+
+ # [b, n, 3, 3]
+ pw_groups_pred = self.form_pw_groups(p123, pw_pred)
+ pw_groups_pred[pw_groups_pred[:, :, 2, :] == 0] = 0.0001
+ mask_broadcast = mask.repeat(1, 9).reshape(B, 3, 3, -1).permute(0, 3, 1, 2).contiguous()
+ pw_groups_pred_not_ignore = pw_groups_pred[mask_broadcast].reshape(1, -1, 3, 3)
+ pw_groups_gt_not_ignore = pw_groups_gt[mask_broadcast].reshape(1, -1, 3, 3)
+
+ return pw_groups_gt_not_ignore, pw_groups_pred_not_ignore
+
+ def forward(self, prediction, target, mask, intrinsic, select=True, **kwargs): #gt_depth, pred_depth, select=True):
+ """
+ Virtual normal loss.
+        :param prediction: predicted depth map, [B, C, H, W]
+        :param target: ground-truth depth, [B, C, H, W], padding region [padding_up, padding_down]
+ :return:
+ """
+ loss = self.get_loss(prediction, target, mask, intrinsic, select, **kwargs)
+ return loss
+
+
+ def get_loss(self, prediction, target, mask, intrinsic, select=True, **kwargs):
+ # configs for the cameras
+ # focal_length = intrinsic[:, 0, 0][:, None, None, None]
+ # u0 = intrinsic[:, 0, 2][:, None, None, None]
+ # v0 = intrinsic[:, 1, 2][:, None, None, None]
+ B, _, H, W = target.shape
+ if 'u_m_u0' not in self._buffers or 'v_m_v0' not in self._buffers \
+ or self.u_m_u0.shape != torch.Size([B,1,H,W]) or self.v_m_v0.shape != torch.Size([B,1,H,W]):
+ self.init_image_coor(intrinsic, H, W)
+
+
+ gt_points, pred_points = self.select_points_groups(target, prediction, intrinsic, mask)
+
+ gt_p12 = gt_points[:, :, :, 1] - gt_points[:, :, :, 0]
+ gt_p13 = gt_points[:, :, :, 2] - gt_points[:, :, :, 0]
+ pred_p12 = pred_points[:, :, :, 1] - pred_points[:, :, :, 0]
+ pred_p13 = pred_points[:, :, :, 2] - pred_points[:, :, :, 0]
+
+ gt_normal = torch.cross(gt_p12, gt_p13, dim=2)
+ pred_normal = torch.cross(pred_p12, pred_p13, dim=2)
+ pred_norm = torch.norm(pred_normal, 2, dim=2, keepdim=True)
+ gt_norm = torch.norm(gt_normal, 2, dim=2, keepdim=True)
+ pred_mask = pred_norm == 0.0
+ gt_mask = gt_norm == 0.0
+ pred_mask = pred_mask.to(torch.float32)
+ gt_mask = gt_mask.to(torch.float32)
+ pred_mask *= self.eps
+ gt_mask *= self.eps
+ gt_norm = gt_norm + gt_mask
+ pred_norm = pred_norm + pred_mask
+ gt_normal = gt_normal / gt_norm
+ pred_normal = pred_normal / pred_norm
+ loss = torch.abs(gt_normal - pred_normal)
+ loss = torch.sum(torch.sum(loss, dim=2), dim=0)
+ if select:
+ loss, indices = torch.sort(loss, dim=0, descending=False)
+ loss = loss[int(loss.size(0) * 0.25):]
+ loss = torch.sum(loss) / (loss.numel() + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ loss = 0 * torch.sum(prediction)
+ print(f'VNL NAN error, {loss}')
+ return loss * self.loss_weight
+
+
+if __name__ == '__main__':
+ import cv2
+ vnl_loss = VNLoss()
+ pred_depth = np.random.random([2, 1, 480, 640])
+ gt_depth = np.zeros_like(pred_depth) #np.random.random([2, 1, 480, 640])
+ intrinsic = [[[100, 0, 200], [0, 100, 200], [0, 0, 1]], [[100, 0, 200], [0, 100, 200], [0, 0, 1]],]
+ gt_depth = torch.tensor(np.array(gt_depth, np.float32)).cuda()
+ pred_depth = torch.tensor(np.array(pred_depth, np.float32)).cuda()
+ intrinsic = torch.tensor(np.array(intrinsic, np.float32)).cuda()
+ mask = gt_depth > 0
+ loss1 = vnl_loss(pred_depth, gt_depth, mask, intrinsic)
+ loss2 = vnl_loss(pred_depth, gt_depth, mask, intrinsic)
+ print(loss1, loss2)
diff --git a/training/mono/model/losses/WCEL.py b/training/mono/model/losses/WCEL.py
new file mode 100644
index 0000000000000000000000000000000000000000..a60c5e60a0f4500b42163645e60d3554d119adfa
--- /dev/null
+++ b/training/mono/model/losses/WCEL.py
@@ -0,0 +1,157 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class WCELoss(nn.Module):
+ """
+ Weighted Cross-entropy Loss Function.
+ """
+ def __init__(self, depth_normalize, out_channel=200, loss_weight=1.0, data_type=['stereo', 'lidar'], **kwargs):
+ super(WCELoss, self).__init__()
+ self.loss_weight = loss_weight
+ self.depth_min = depth_normalize[0]
+ self.depth_max = depth_normalize[1]
+ self.bins_num = out_channel
+ self.depth_min_log = torch.log10(torch.tensor(self.depth_min))
+
+ self.alpha = 2 #0.2
+ self.config_bins()
+ self.noise_sample_ratio = 0.9 #kwargs['noise_sample_ratio'] if 'noise_sample_ratio' in kwargs else 1.0
+ self.data_type = data_type
+ self.eps = 1e-6
+
+ def config_bins(self):
+ # Modify some configs
+ self.depth_bins_interval = (torch.log10(torch.tensor(self.depth_max)) -
+ self.depth_min_log) / self.bins_num
+ bins_edges_in_log = self.depth_min_log + self.depth_bins_interval * torch.tensor(list(range(self.bins_num)) + [self.bins_num,])
+ #bins_edges_in_log = torch.from_numpy(bins_edges_in_log)
+ # The boundary of each bin
+ # bins_edges_in_log = np.array([self.depth_min_log + self.depth_bins_interval * (i + 0.5)
+ # for i in range(self.bins_num)])
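+        # Soft ordinal weights between bins: w_ij = exp(-alpha * (i - j)^2), so bins near
+        # the ground-truth bin also receive exponentially decaying weight in the CE loss.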
+ bins_weight = torch.tensor([[np.exp(-self.alpha * (i - j) ** 2) for i in range(self.bins_num )]
+ for j in np.arange(self.bins_num )]).cuda()
+ self.register_buffer("bins_weight", bins_weight.float(), persistent=False)
+ self.register_buffer("bins_edges_in_log", bins_edges_in_log.float(), persistent=False)
+
+ def depth_to_bins_in_log(self, depth, mask):
+ """
+ Discretize depth into depth bins. Predefined bins edges are in log space.
+ Mark invalid padding area as bins_num + 1
+ Args:
+ @depth: 1-channel depth, [B, 1, h, w]
+ return: depth bins [B, C, h, w]
+ """
+ invalid_mask = ~mask
+ #depth[depth < self.depth_min] = self.depth_min
+ #depth[depth > self.depth_max] = self.depth_max
+ mask_lower = (depth <= self.depth_min)
+ mask_higher = (depth >= self.depth_max)
+ depth_bins_log = ((torch.log10(torch.abs(depth)) - self.depth_min_log) / self.depth_bins_interval).to(torch.int)
+
+ depth_bins_log[mask_lower] = 0
+ depth_bins_log[mask_higher] = self.bins_num - 1
+ depth_bins_log[depth_bins_log == self.bins_num] = self.bins_num - 1
+
+ depth_bins_log[invalid_mask] = self.bins_num + 1
+ return depth_bins_log
+
+ def depth_to_bins(self, depth, mask, depth_edges, size_limite=(300, 300)):
+ """
+ Discretize depth into depth bins. Predefined bins edges are provided.
+ Mark invalid padding area as bins_num + 1
+ Args:
+ @depth: 1-channel depth, [B, 1, h, w]
+ return: depth bins [B, C, h, w]
+ """
+ def _depth_to_bins_block_(depth, mask, depth_edges):
+ bins_id = torch.sum(depth_edges[:, None, None, None, :] < torch.abs(depth)[:, :, :, :, None], dim=-1)
+ bins_id = bins_id - 1
+ invalid_mask = ~mask
+ mask_lower = (depth <= self.depth_min)
+ mask_higher = (depth >= self.depth_max)
+
+ bins_id[mask_lower] = 0
+ bins_id[mask_higher] = self.bins_num - 1
+ bins_id[bins_id == self.bins_num] = self.bins_num - 1
+
+ bins_id[invalid_mask] = self.bins_num + 1
+ return bins_id
+ _, _, H, W = depth.shape
+ bins = mask.clone().long()
+        h_blocks = np.ceil(H / size_limite[0]).astype(int)
+        w_blocks = np.ceil(W / size_limite[1]).astype(int)
+ for i in range(h_blocks):
+ for j in range(w_blocks):
+ h_start = i*size_limite[0]
+ h_end_proposal = (i + 1) * size_limite[0]
+ h_end = h_end_proposal if h_end_proposal < H else H
+ w_start = j*size_limite[1]
+ w_end_proposal = (j + 1) * size_limite[1]
+ w_end = w_end_proposal if w_end_proposal < W else W
+ bins_ij = _depth_to_bins_block_(
+ depth[:, :, h_start:h_end, w_start:w_end],
+ mask[:, :, h_start:h_end, w_start:w_end],
+ depth_edges
+ )
+ bins[:, :, h_start:h_end, w_start:w_end] = bins_ij
+ return bins
+
+
+ # def mask_maximum_loss(self, loss_pixels, mask):
+ # mask = mask.reshape(mask.size(0), -1)
+ # valid_pix_bt = torch.sum(mask, dim=1)
+ # mask_noise_num = (valid_pix_bt * self.noise_sample_ratio).int()
+
+ # loss_sample = []
+ # for i in range(loss_pixels.size(0)):
+ # sorted_losses, _ = torch.sort(loss_pixels[i, :][mask[i, ...]])
+ # loss_sample.append(torch.sum(sorted_losses[:mask_noise_num[i]]))
+
+ # return torch.tensor(loss_sample), mask_noise_num
+
+
+ def forward(self, prediction, target, mask=None, pred_logit=None, **kwargs): #pred_logit, gt_bins, gt):
+ B, _, H, W = target.shape
+ if 'bins_edges' not in kwargs or kwargs['bins_edges'] is None:
+ # predefined depth bins in log space
+ gt_bins = self.depth_to_bins_in_log(target, mask)
+ else:
+ bins_edges = kwargs['bins_edges']
+ gt_bins = self.depth_to_bins(target, mask, bins_edges)
+
+ classes_range = torch.arange(self.bins_num, device=gt_bins.device, dtype=gt_bins.dtype)
+ log_pred = torch.nn.functional.log_softmax(pred_logit, 1)
+ log_pred = log_pred.reshape(B, log_pred.size(1), -1).permute((0, 2, 1))
+ gt_reshape = gt_bins.reshape((B, -1))[:, :, None]
+ one_hot = (gt_reshape == classes_range).to(dtype=torch.float, device=pred_logit.device)
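+        # matmul with bins_weight (a Gaussian over bin-index distance) turns the hard
+        # one-hot gt bin into a soft label spread over neighbouring bins, so the
+        # cross-entropy penalty grows with the distance from the gt bin.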
+ weight = torch.matmul(one_hot, self.bins_weight)
+ weight_log_pred = weight * log_pred
+        loss_pixels = -torch.sum(weight_log_pred, dim=2)
+
+ valid_pixels = torch.sum(mask).to(dtype=torch.float, device=pred_logit.device)
+        loss = torch.sum(loss_pixels) / (valid_pixels + self.eps)
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+ raise RuntimeError(f'WCEL error, {loss}')
+ return loss * self.loss_weight
+
+
+
+if __name__ == '__main__':
+ import cv2
+ wcel = WCELoss((0.0004, 1))
+ pred_depth = np.abs(np.random.random([2, 1, 480, 640]))
+ pred_logit = np.random.random([2, 200, 480, 640])
+ gt_depth = np.random.random([2, 1, 480, 640]) - 0.5 #np.zeros_like(pred_depth) #
+ intrinsic = [[100, 100, 200, 200], [200, 200, 300, 300]]
+ gt_depth = torch.tensor(np.array(gt_depth, np.float32)).cuda()
+ pred_depth = torch.tensor(np.array(pred_depth, np.float32)).cuda()
+ intrinsic = torch.tensor(np.array(intrinsic, np.float32)).cuda()
+ pred_logit = torch.tensor(np.array(pred_logit, np.float32)).cuda()
+
+
+ mask = gt_depth > 0
+ loss1 = wcel(gt_depth, gt_depth, mask, intrinsic=intrinsic, pred_logit=pred_logit)
+ loss2 = wcel(gt_depth, gt_depth, mask, intrinsic=intrinsic, pred_logit=pred_logit)
+ print(loss1, loss2)
diff --git a/training/mono/model/losses/__init__.py b/training/mono/model/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..97df57e14af288adea69c2a2df237da7ae580787
--- /dev/null
+++ b/training/mono/model/losses/__init__.py
@@ -0,0 +1,32 @@
+from .SiLog import SilogLoss
+from .WCEL import WCELoss
+from .VNL import VNLoss
+from .Gradient import GradientLoss_Li, GradientLoss
+from .Ranking import EdgeguidedRankingLoss, RankingLoss
+from .Regularization import RegularizationLoss
+from .SSIL import SSILoss
+from .HDNL import HDNLoss
+from .HDSNL import HDSNLoss
+from .NormalRegression import EdgeguidedNormalLoss
+from .depth_to_normal import Depth2Normal
+from .photometric_loss_functions import PhotometricGeometricLoss
+from .HDSNL_random import HDSNRandomLoss
+from .HDNL_random import HDNRandomLoss
+from .AdabinsLoss import AdabinsLoss
+from .SkyRegularization import SkyRegularizationLoss
+from .PWN_Planes import PWNPlanesLoss
+from .L1 import L1Loss, L1DispLoss, L1InverseLoss
+from .ConfidenceLoss import ConfidenceLoss
+from .ScaleInvL1 import ScaleInvL1Loss
+from .NormalBranchLoss import NormalBranchLoss, DeNoConsistencyLoss
+from .GRUSequenceLoss import GRUSequenceLoss
+from .ConfidenceGuideLoss import ConfidenceGuideLoss
+from .ScaleAlignLoss import ScaleAlignLoss
+
+__all__ = [
+ 'SilogLoss', 'WCELoss', 'VNLoss', 'GradientLoss_Li', 'GradientLoss', 'EdgeguidedRankingLoss',
+ 'RankingLoss', 'RegularizationLoss', 'SSILoss', 'HDNLoss', 'HDSNLoss', 'EdgeguidedNormalLoss', 'Depth2Normal',
+ 'PhotometricGeometricLoss', 'HDSNRandomLoss', 'HDNRandomLoss', 'AdabinsLoss', 'SkyRegularizationLoss',
+ 'PWNPlanesLoss', 'L1Loss',
+ 'ConfidenceLoss', 'ScaleInvL1Loss', 'L1DispLoss', 'NormalBranchLoss', 'L1InverseLoss', 'GRUSequenceLoss', 'ConfidenceGuideLoss', 'DeNoConsistencyLoss', 'ScaleAlignLoss'
+]
diff --git a/training/mono/model/losses/depth_to_normal.py b/training/mono/model/losses/depth_to_normal.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0b1892ad35a964673bd550e73c255ca530fc2dd
--- /dev/null
+++ b/training/mono/model/losses/depth_to_normal.py
@@ -0,0 +1,302 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+class Backprojection(nn.Module):
+ """Layer to backproject a depth image given the camera intrinsics
+ Attributes
+ xy (Nx3x(HxW)): homogeneous pixel coordinates on regular grid
+ """
+ def __init__(self, height, width):
+ """
+ Args:
+ height (int): image height
+ width (int): image width
+ """
+ super(Backprojection, self).__init__()
+
+ self.height = height
+ self.width = width
+
+ # generate regular grid
+ meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
+ id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
+ id_coords = torch.tensor(id_coords, device="cuda")
+
+ # generate homogeneous pixel coordinates
+ # self.ones = nn.Parameter(torch.ones(1, 1, self.height * self.width),
+ # requires_grad=False)
+ ones = torch.ones(1, 1, self.height * self.width, device="cuda")
+ xy = torch.unsqueeze(
+ torch.stack([id_coords[0].view(-1), id_coords[1].view(-1)], 0),
+ 0
+ )
+ xy = torch.cat([xy, ones], 1)
+ #self.xy = nn.Parameter(self.xy, requires_grad=False)
+ self.register_buffer('xy', xy, persistent=False)
+ self.register_buffer('ones', ones, persistent=False)
+
+ # for virtual camera only
+ horizontal_angle_range=[195.0, -15.0]
+ vertical_angle_range=[150.0, 0.0]
+
+ horizontal_sample_num=641
+ vertical_sample_num=481
+
+ self.horizontal_angle_range = horizontal_angle_range
+ self.vertical_angle_range = vertical_angle_range
+ self.horizontal_sample_num = horizontal_sample_num
+ self.vertical_sample_num = vertical_sample_num
+
+ self.horizontal_step = (self.horizontal_angle_range[1] - self.horizontal_angle_range[0]) / (
+ self.horizontal_sample_num - 1)
+ self.vertical_step = (self.vertical_angle_range[1] - self.vertical_angle_range[0]) / (
+ self.vertical_sample_num - 1)
+
+ self.horizontal_samples = np.arange(self.horizontal_angle_range[0], self.horizontal_angle_range[1],
+ self.horizontal_step)
+ self.vertical_samples = np.arange(self.vertical_angle_range[0], self.vertical_angle_range[1],
+ self.vertical_step)
+
+ horizontal_samples_in_rad = self.horizontal_samples / 180.0 * np.pi
+ vertical_samples_in_rad = self.vertical_samples / 180.0 * np.pi
+
+ virt_H = len(self.vertical_samples)
+ virt_W = len(self.horizontal_samples)
+
+ self.virt_H, self.virt_W = virt_H, virt_W
+
+ cos_theta = np.tile(np.cos(vertical_samples_in_rad).reshape(-1, 1), (1, virt_W))
+ sin_theta = np.tile(np.sin(vertical_samples_in_rad).reshape(-1, 1), (1, virt_W))
+ cos_phi = np.tile(np.cos(horizontal_samples_in_rad).reshape(1, -1), (virt_H, 1))
+ sin_phi = np.tile(np.sin(horizontal_samples_in_rad).reshape(1, -1), (virt_H, 1))
+
+ x = (sin_theta * cos_phi).reshape(1, virt_H, virt_W)
+ y = cos_theta.reshape(1, virt_H, virt_W)
+ z = (sin_theta * sin_phi).reshape(1, virt_H, virt_W)
+
+ self.dir_in_virt_cam = np.concatenate((x, y, z), axis=0)
+ self.dir_in_virt_cam = self.dir_in_virt_cam.reshape(3, self.virt_H * self.virt_W)
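+        # dir_in_virt_cam holds the unit viewing directions of a virtual spherical
+        # camera, sampled on the (vertical, horizontal) angle grid defined above.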
+
+
+ def forward(self, depth, inv_K, img_like_out=False):
+ """
+ Args:
+ depth (Nx1xHxW): depth map
+ inv_K (Nx4x4): inverse camera intrinsics
+ img_like_out (bool): if True, the output shape is Nx4xHxW; else Nx4x(HxW)
+ Returns:
+ points (Nx4x(HxW)): 3D points in homogeneous coordinates
+ """
+ depth = depth.contiguous()
+
+ xy = self.xy.repeat(depth.shape[0], 1, 1)
+ ones = self.ones.repeat(depth.shape[0],1,1)
+
+ points = torch.matmul(inv_K[:, :3, :3], xy)
+ points = depth.view(depth.shape[0], 1, -1) * points
+ points = torch.cat([points, ones], 1)
+
+ if img_like_out:
+ points = points.reshape(depth.shape[0], 4, self.height, self.width)
+ return points
+
+
+def get_surface_normalv2(xyz, patch_size=5, mask_valid=None):
+ """
+ xyz: xyz coordinates, in [b, h, w, c]
+ patch: [p1, p2, p3,
+ p4, p5, p6,
+ p7, p8, p9]
+    surface_normal ~ normalize[(p4-p6) x (p2-p8)], averaged over two neighbourhood scales
+    return: normal [b, 3, h, w], valid mask [b, 1, h, w]
+ """
+ b, h, w, c = xyz.shape
+ half_patch = patch_size // 2
+
+    if mask_valid is None:
+ mask_valid = xyz[:, :, :, 2] > 0 # [b, h, w]
+ mask_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1), device=mask_valid.device).bool()
+ mask_pad[:, half_patch:-half_patch, half_patch:-half_patch] = mask_valid
+
+ xyz_pad = torch.zeros((b, h + patch_size - 1, w + patch_size - 1, c), dtype=xyz.dtype, device=xyz.device)
+ xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz
+
+ xyz_left = xyz_pad[:, half_patch:half_patch + h, :w, :] # p4
+ xyz_right = xyz_pad[:, half_patch:half_patch + h, -w:, :] # p6
+ xyz_top = xyz_pad[:, :h, half_patch:half_patch + w, :] # p2
+ xyz_bottom = xyz_pad[:, -h:, half_patch:half_patch + w, :] # p8
+ xyz_horizon = xyz_left - xyz_right # p4p6
+ xyz_vertical = xyz_top - xyz_bottom # p2p8
+
+ xyz_left_in = xyz_pad[:, half_patch:half_patch + h, 1:w+1, :] # p4
+ xyz_right_in = xyz_pad[:, half_patch:half_patch + h, patch_size-1:patch_size-1+w, :] # p6
+ xyz_top_in = xyz_pad[:, 1:h+1, half_patch:half_patch + w, :] # p2
+ xyz_bottom_in = xyz_pad[:, patch_size-1:patch_size-1+h, half_patch:half_patch + w, :] # p8
+ xyz_horizon_in = xyz_left_in - xyz_right_in # p4p6
+ xyz_vertical_in = xyz_top_in - xyz_bottom_in # p2p8
+
+ n_img_1 = torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)
+ n_img_2 = torch.cross(xyz_horizon, xyz_vertical, dim=3)
+
+ # re-orient normals consistently
+ orient_mask = torch.sum(n_img_1 * xyz, dim=3) > 0
+ n_img_1[orient_mask] *= -1
+ orient_mask = torch.sum(n_img_2 * xyz, dim=3) > 0
+ n_img_2[orient_mask] *= -1
+
+ n_img1_L2 = torch.sqrt(torch.sum(n_img_1 ** 2, dim=3, keepdim=True) + 1e-4)
+ n_img1_norm = n_img_1 / (n_img1_L2 + 1e-8)
+
+ n_img2_L2 = torch.sqrt(torch.sum(n_img_2 ** 2, dim=3, keepdim=True) + 1e-4)
+ n_img2_norm = n_img_2 / (n_img2_L2 + 1e-8)
+
+ # average 2 norms
+ n_img_aver = n_img1_norm + n_img2_norm
+ n_img_aver_L2 = torch.sqrt(torch.sum(n_img_aver ** 2, dim=3, keepdim=True) + 1e-4)
+ n_img_aver_norm = n_img_aver / (n_img_aver_L2 + 1e-8)
+ # re-orient normals consistently
+ orient_mask = torch.sum(n_img_aver_norm * xyz, dim=3) > 0
+ n_img_aver_norm[orient_mask] *= -1
+ #n_img_aver_norm_out = n_img_aver_norm.permute((1, 2, 3, 0)) # [h, w, c, b]
+
+ # get mask for normals
+ mask_p4p6 = mask_pad[:, half_patch:half_patch + h, :w] & mask_pad[:, half_patch:half_patch + h, -w:]
+ mask_p2p8 = mask_pad[:, :h, half_patch:half_patch + w] & mask_pad[:, -h:, half_patch:half_patch + w]
+ mask_normal = mask_p2p8 & mask_p4p6
+ n_img_aver_norm[~mask_normal] = 0
+
+ # a = torch.sum(n_img1_norm_out*n_img2_norm_out, dim=2).cpu().numpy().squeeze()
+ # plt.imshow(np.abs(a), cmap='rainbow')
+ # plt.show()
+    return n_img_aver_norm.permute(0, 3, 1, 2).contiguous(), mask_normal[:, None, :, :] # [b, 3, h, w], [b, 1, h, w]
+
+class Depth2Normal(nn.Module):
+ """Layer to compute surface normal from depth map
+ """
+ def __init__(self,):
+ """
+ Args:
+ height (int): image height
+ width (int): image width
+ """
+ super(Depth2Normal, self).__init__()
+
+ def init_img_coor(self, height, width):
+ """
+ Args:
+ height (int): image height
+ width (int): image width
+ """
+ y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device="cuda"),
+ torch.arange(0, width, dtype=torch.float32, device="cuda")], indexing='ij')
+ meshgrid = torch.stack((x, y))
+
+ # # generate regular grid
+ # meshgrid = np.meshgrid(range(width), range(height), indexing='xy')
+ # id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
+ # id_coords = torch.tensor(id_coords)
+
+ # generate homogeneous pixel coordinates
+ ones = torch.ones((1, 1, height * width), device="cuda")
+ # xy = torch.unsqueeze(
+ # torch.stack([x.reshape(-1), y.reshape(-1)], 0),
+ # 0
+ # )
+ xy = meshgrid.reshape(2, -1).unsqueeze(0)
+ xy = torch.cat([xy, ones], 1)
+
+ self.register_buffer('xy', xy, persistent=False)
+
+ def back_projection(self, depth, inv_K, img_like_out=False, scale=1.0):
+ """
+ Args:
+ depth (Nx1xHxW): depth map
+ inv_K (Nx4x4): inverse camera intrinsics
+ img_like_out (bool): if True, the output shape is Nx4xHxW; else Nx4x(HxW)
+ Returns:
+ points (Nx4x(HxW)): 3D points in homogeneous coordinates
+ """
+ B, C, H, W = depth.shape
+ depth = depth.contiguous()
+ # xy = self.init_img_coor(height=H, width=W)
+ xy = self.xy # xy.repeat(depth.shape[0], 1, 1)
+ #ones = self.ones.repeat(depth.shape[0],1,1)
+
+ points = torch.matmul(inv_K[:, :3, :3], xy)
+ points = depth.view(depth.shape[0], 1, -1) * points
+ depth_descale = points[:, 2:3, :] / scale
+ points = torch.cat((points[:, 0:2, :], depth_descale), dim=1)
+ #points = torch.cat([points, ones], 1)
+
+ if img_like_out:
+ points = points.reshape(depth.shape[0], 3, H, W)
+ return points
+
+ # def transfer_xyz(self, u0, v0, H, W, depth, focal_length):
+ # x_row = np.arange(0, W)
+ # x = np.tile(x_row, (H, 1))
+ # x = x.astype(np.float32)
+ # x = torch.from_numpy(x.copy()).cuda()
+ # u_m_u0 = x[None, None, :, :] - u0
+ # self.register_buffer('u_m_u0', u_m_u0, persistent=False)
+
+ # y_col = np.arange(0, H) # y_col = np.arange(0, height)
+ # y = np.tile(y_col, (W, 1)).T
+ # y = y.astype(np.float32)
+ # y = torch.from_numpy(y.copy()).cuda()
+ # v_m_v0 = y[None, None, :, :] - v0
+ # self.register_buffer('v_m_v0', v_m_v0, persistent=False)
+
+ # pix_idx_mat = torch.arange(H*W).reshape((H, W)).cuda()
+ # self.register_buffer('pix_idx_mat', pix_idx_mat, persistent=False)
+
+ # x = self.u_m_u0 * depth / focal_length
+ # y = self.v_m_v0 * depth / focal_length
+ # z = depth
+ # pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1) # [b, h, w, c]
+ # return pw
+
+ def forward(self, depth, intrinsics, masks, scale):
+ """
+ Args:
+ depth (Nx1xHxW): depth map
+ #inv_K (Nx4x4): inverse camera intrinsics
+ intrinsics (Nx4): camera intrinsics
+ Returns:
+ normal (Nx3xHxW): normalized surface normal
+ mask (Nx1xHxW): valid mask for surface normal
+ """
+ B, C, H, W = depth.shape
+ if 'xy' not in self._buffers or self.xy.shape[-1] != H*W:
+ self.init_img_coor(height=H, width=W)
+ # Compute 3D point cloud
+ inv_K = intrinsics.inverse()
+
+ xyz = self.back_projection(depth, inv_K, scale=scale) # [N, 4, HxW]
+
+ xyz = xyz.view(depth.shape[0], 3, H, W)
+ xyz = xyz[:,:3].permute(0, 2, 3, 1).contiguous() # [b, h, w, c]
+
+ # focal_length = intrinsics[:, 0, 0][:, None, None, None]
+ # u0 = intrinsics[:, 0, 2][:, None, None, None]
+ # v0 = intrinsics[:, 1, 2][:, None, None, None]
+ # xyz2 = self.transfer_xyz(u0, v0, H, W, depth, focal_length)
+
+        normals, normal_masks = get_surface_normalv2(xyz, mask_valid=masks.squeeze(1))
+ normal_masks = normal_masks & masks
+ return normals, normal_masks
+
+
+
+if __name__ == '__main__':
+ d2n = Depth2Normal()
+    depth = np.abs(np.random.randn(2, 1, 20, 22))
+    intrin = np.array([[300, 0, 10], [0, 300, 10], [0, 0, 1]])
+    intrinsics = np.stack([intrin, intrin], axis=0)
+
+    depth_t = torch.from_numpy(depth).cuda().float()
+    intrinsics = torch.from_numpy(intrinsics).cuda().float()
+    masks = torch.ones_like(depth_t, dtype=torch.bool)
+    normal, normal_mask = d2n(depth_t, intrinsics, masks, scale=1.0)
+    normal2, _ = d2n(depth_t, intrinsics, masks, scale=1.0)
+    print(normal.shape, normal_mask.shape)
\ No newline at end of file
diff --git a/training/mono/model/losses/photometric_loss_functions.py b/training/mono/model/losses/photometric_loss_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f6d16723574e0d1e48ad4c3b398a3dbb8939ca1
--- /dev/null
+++ b/training/mono/model/losses/photometric_loss_functions.py
@@ -0,0 +1,300 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import numpy as np
+
+from mono.utils.inverse_warp import inverse_warp2
+
+#device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+
+class SSIM(nn.Module):
+ """Layer to compute the SSIM loss between a pair of images
+ """
+ def __init__(self):
+ super(SSIM, self).__init__()
+ k = 7
+ self.mu_x_pool = nn.AvgPool2d(k, 1)
+ self.mu_y_pool = nn.AvgPool2d(k, 1)
+ self.sig_x_pool = nn.AvgPool2d(k, 1)
+ self.sig_y_pool = nn.AvgPool2d(k, 1)
+ self.sig_xy_pool = nn.AvgPool2d(k, 1)
+
+ self.refl = nn.ReflectionPad2d(k//2)
+
+ self.C1 = 0.01 ** 2
+ self.C2 = 0.03 ** 2
+
+ def forward(self, x, y):
+ x = self.refl(x)
+ y = self.refl(y)
+
+ mu_x = self.mu_x_pool(x)
+ mu_y = self.mu_y_pool(y)
+
+ sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
+ sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
+ sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
+
+ SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
+ SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
+
+ return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
+
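+# Minimal usage sketch (illustrative): SSIM() returns a per-pixel DSSIM map,
+# i.e. (1 - SSIM) / 2 clamped to [0, 1], for a pair of image batches:
+#   dssim = SSIM()(torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64))  # (1, 3, 64, 64)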
+
+class PhotometricGeometricLoss(nn.Module):
+ """The photometric and geometric loss between target and reference frames."""
+ def __init__(self, loss_weight=1.0, data_type=['sfm', 'stereo', 'lidar'], **kwargs):
+ super(PhotometricGeometricLoss, self).__init__()
+ self.no_min_optimize = False
+ self.no_auto_mask = False
+ self.return_dynamic_mask = True
+ self.ssim_loss = SSIM()
+ self.no_ssim = False
+ self.no_dynamic_mask = False
+ self.loss_weight_photo = 1.0
+ self.loss_weight_geometry = 0.5
+ self.total_loss_weight = loss_weight
+ self.data_type = data_type
+
+
+ def photo_and_geometry_loss(self, tgt_img, ref_imgs, tgt_depth, ref_depths, intrinsics, poses, poses_inv):
+
+ diff_img_list = []
+ diff_color_list = []
+ diff_depth_list = []
+ valid_mask_list = []
+ auto_mask_list = []
+
+ for ref_img, ref_depth, pose, pose_inv in zip(ref_imgs, ref_depths, poses, poses_inv):
+ (
+ diff_img_tmp1,
+ diff_color_tmp1,
+ diff_depth_tmp1,
+ valid_mask_tmp1,
+ auto_mask_tmp1
+ ) = self.compute_pairwise_loss(
+ tgt_img,
+ ref_img,
+ tgt_depth,
+ ref_depth,
+ pose,
+ intrinsics,
+ )
+
+ (
+ diff_img_tmp2,
+ diff_color_tmp2,
+ diff_depth_tmp2,
+ valid_mask_tmp2,
+ auto_mask_tmp2
+ ) = self.compute_pairwise_loss(
+ ref_img,
+ tgt_img,
+ ref_depth,
+ tgt_depth,
+ pose_inv,
+ intrinsics,
+ )
+
+ diff_img_list += [diff_img_tmp1, diff_img_tmp2]
+ diff_color_list += [diff_color_tmp1, diff_color_tmp2]
+ diff_depth_list += [diff_depth_tmp1, diff_depth_tmp2]
+ valid_mask_list += [valid_mask_tmp1, valid_mask_tmp2]
+ auto_mask_list += [auto_mask_tmp1, auto_mask_tmp2]
+
+ diff_img = torch.cat(diff_img_list, dim=1)
+ diff_color = torch.cat(diff_color_list, dim=1)
+ diff_depth = torch.cat(diff_depth_list, dim=1)
+ valid_mask = torch.cat(valid_mask_list, dim=1)
+ auto_mask = torch.cat(auto_mask_list, dim=1)
+
+ # using photo loss to select best match in multiple views
+ if not self.no_min_optimize:
+ indices = torch.argmin(diff_color, dim=1, keepdim=True)
+
+ diff_img = torch.gather(diff_img, 1, indices)
+ diff_depth = torch.gather(diff_depth, 1, indices)
+ valid_mask = torch.gather(valid_mask, 1, indices)
+ auto_mask = torch.gather(auto_mask, 1, indices)
+
+ if not self.no_auto_mask:
+ photo_loss = self.mean_on_mask(diff_img, valid_mask * auto_mask)
+ geometry_loss = self.mean_on_mask(diff_depth, valid_mask * auto_mask)
+ else:
+ photo_loss = self.mean_on_mask(diff_img, valid_mask)
+ geometry_loss = self.mean_on_mask(diff_depth, valid_mask)
+
+ dynamic_mask = None
+ if self.return_dynamic_mask:
+ # get dynamic mask for tgt image
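+            # 1 - diff_depth acts as a per-pixel "static scene" confidence: large
+            # geometric inconsistency after warping (moving objects, occlusions)
+            # drives the mask towards 0.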
+ dynamic_mask_list = []
+ for i in range(0, len(diff_depth_list), 2):
+ tmp = diff_depth_list[i]
+ tmp[valid_mask_list[1]<1] = 0
+ dynamic_mask_list += [1-tmp]
+
+ dynamic_mask = torch.cat(dynamic_mask_list, dim=1).mean(dim=1, keepdim=True)
+
+ return photo_loss, geometry_loss, dynamic_mask
+
+
+ def compute_pairwise_loss(self, tgt_img, ref_img, tgt_depth, ref_depth, pose, intrinsic):
+
+ ref_img_warped, projected_depth, computed_depth = inverse_warp2(ref_img, tgt_depth, ref_depth, pose, intrinsic, padding_mode='zeros')
+
+
+ diff_depth = (computed_depth-projected_depth).abs()/(computed_depth+projected_depth)
+
+ # masking zero values
+ valid_mask_ref = (ref_img_warped.abs().mean(dim=1, keepdim=True) > 1e-3).float()
+ valid_mask_tgt = (tgt_img.abs().mean(dim=1, keepdim=True) > 1e-3).float()
+ valid_mask = valid_mask_tgt * valid_mask_ref
+
+ diff_color = (tgt_img-ref_img_warped).abs().mean(dim=1, keepdim=True)
+ identity_warp_err = (tgt_img-ref_img).abs().mean(dim=1, keepdim=True)
+        auto_mask = (diff_color < identity_warp_err).float()
+
+        diff_img = diff_color
+        if not self.no_ssim:
+            ssim_map = self.ssim_loss(tgt_img, ref_img_warped)
+            diff_img = 0.15 * diff_color + 0.85 * ssim_map.mean(dim=1, keepdim=True)
+
+        # reduce photometric loss weight for dynamic regions
+        if not self.no_dynamic_mask:
+            diff_img = diff_img * (1 - diff_depth)
+
+        return diff_img, diff_color, diff_depth, valid_mask, auto_mask
+
+
+    def mean_on_mask(self, diff, mask):
+        # if mask.sum() > 100:
+ # mean_value = (diff * mask).sum() / mask.sum()
+ # else:
+ # mean_value = torch.tensor(0).float().to(device)
+ mean_value = (diff * mask).sum() / (mask.sum() + 1e-6)
+ return mean_value
+
+
+ def forward(self, input, ref_input, prediction, ref_prediction, intrinsic, **kwargs):
+ photo_loss, geometry_loss, dynamic_mask = self.photo_and_geometry_loss(
+ tgt_img=input,
+ ref_imgs=ref_input,
+ tgt_depth=prediction,
+ ref_depths=ref_prediction,
+ intrinsics=intrinsic,
+ poses=kwargs['pose'],
+ poses_inv=kwargs['inv_pose'])
+ loss = self.loss_weight_geometry * geometry_loss + self.loss_weight_photo * photo_loss
+ if torch.isnan(loss).item() | torch.isinf(loss).item():
+            raise RuntimeError(f'Photometric-geometric loss error, {loss}')
+ return loss * self.total_loss_weight
+
+
+
+
+
+
+
+
+# def compute_smooth_loss(tgt_depth, tgt_img):
+# def get_smooth_loss(disp, img):
+# """
+# Computes the smoothness loss for a disparity image
+# The color image is used for edge-aware smoothness
+# """
+
+# # normalize
+# mean_disp = disp.mean(2, True).mean(3, True)
+# norm_disp = disp / (mean_disp + 1e-7)
+# disp = norm_disp
+
+# grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
+# grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
+
+# grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
+# grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
+
+# grad_disp_x *= torch.exp(-grad_img_x)
+# grad_disp_y *= torch.exp(-grad_img_y)
+
+# return grad_disp_x.mean() + grad_disp_y.mean()
+
+# loss = get_smooth_loss(tgt_depth, tgt_img)
+
+# return loss
+
+
+# @torch.no_grad()
+# def compute_errors(gt, pred, dataset):
+# # pred : b c h w
+# # gt: b h w
+
+# abs_diff = abs_rel = sq_rel = log10 = rmse = rmse_log = a1 = a2 = a3 = 0.0
+
+# batch_size, h, w = gt.size()
+
+# if pred.nelement() != gt.nelement():
+# pred = F.interpolate(pred, [h,w], mode='bilinear', align_corners=False)
+# # pred = F.interpolate(pred, [h,w], mode='nearest')
+
+# pred = pred.view(batch_size, h, w)
+
+# if dataset == 'kitti':
+# crop_mask = gt[0] != gt[0]
+# y1, y2 = int(0.40810811 * gt.size(1)), int(0.99189189 * gt.size(1))
+# x1, x2 = int(0.03594771 * gt.size(2)), int(0.96405229 * gt.size(2))
+# crop_mask[y1:y2, x1:x2] = 1
+# max_depth = 80
+
+# if dataset == 'cs':
+# crop_mask = gt[0] != gt[0]
+# crop_mask[256:, 192:1856] = 1
+# max_depth = 80
+
+# if dataset == 'nyu':
+# crop_mask = gt[0] != gt[0]
+# crop = np.array([45, 471, 41, 601]).astype(np.int32)
+# crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1
+# max_depth = 10
+
+# if dataset == 'bonn':
+# crop_mask = gt[0] != gt[0]
+# crop_mask[:,:] = 1
+# max_depth = 10
+
+# if dataset == 'ddad':
+# crop_mask = gt[0] != gt[0]
+# crop_mask[:,:] = 1
+# max_depth = 200
+
+# min_depth = 1e-3
+# for current_gt, current_pred in zip(gt, pred):
+# valid = (current_gt > min_depth) & (current_gt < max_depth)
+# valid = valid & crop_mask
+
+# valid_gt = current_gt[valid]
+# valid_pred = current_pred[valid]
+
+# # align scale
+# valid_pred = valid_pred * torch.median(valid_gt)/torch.median(valid_pred)
+
+# valid_pred = valid_pred.clamp(min_depth, max_depth)
+
+# thresh = torch.max((valid_gt / valid_pred), (valid_pred / valid_gt))
+# a1 += (thresh < 1.25).float().mean()
+# a2 += (thresh < 1.25 ** 2).float().mean()
+# a3 += (thresh < 1.25 ** 3).float().mean()
+
+# diff_i = valid_gt - valid_pred
+# abs_diff += torch.mean(torch.abs(diff_i))
+# abs_rel += torch.mean(torch.abs(diff_i) / valid_gt)
+# sq_rel += torch.mean(((diff_i)**2) / valid_gt)
+# rmse += torch.sqrt(torch.mean(diff_i ** 2))
+# rmse_log += torch.sqrt(torch.mean((torch.log(valid_gt) - torch.log(valid_pred)) ** 2))
+# log10 += torch.mean(torch.abs((torch.log10(valid_gt) - torch.log10(valid_pred))))
+
+# return [metric.item() / batch_size for metric in [abs_diff, abs_rel, sq_rel, log10, rmse, rmse_log, a1, a2, a3]]
diff --git a/training/mono/model/model_pipelines/__init__.py b/training/mono/model/model_pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..33a66928649795f129f58cdf8aec685d0a5bb77d
--- /dev/null
+++ b/training/mono/model/model_pipelines/__init__.py
@@ -0,0 +1,6 @@
+from .model_pipeline import EncoderDecoder
+from .dense_pipeline import DensePredModel
+
+__all__ = [
+ 'EncoderDecoder', 'DensePredModel'
+]
diff --git a/training/mono/model/model_pipelines/dense_pipeline.py b/training/mono/model/model_pipelines/dense_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cf7d653e0f19ab30fded0d548a282c539861d7d
--- /dev/null
+++ b/training/mono/model/model_pipelines/dense_pipeline.py
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+from mono.utils.comm import get_func
+
+
+class DensePredModel(nn.Module):
+ def __init__(self, cfg):
+ super(DensePredModel, self).__init__()
+
+ self.encoder = get_func('mono.model.' + cfg.model.backbone.prefix + cfg.model.backbone.type)(**cfg.model.backbone)
+ self.decoder = get_func('mono.model.' + cfg.model.decode_head.prefix + cfg.model.decode_head.type)(cfg)
+ # try:
+ # decoder_compiled = torch.compile(decoder, mode='max-autotune')
+ # "Decoder compile finished"
+ # self.decoder = decoder_compiled
+ # except:
+ # "Decoder compile failed, use default setting"
+ # self.decoder = decoder
+
+ self.training = True
+
+ def forward(self, input, **kwargs):
+ # [f_32, f_16, f_8, f_4]
+ features = self.encoder(input)
+ # [x_32, x_16, x_8, x_4, x, ...]
+ out = self.decoder(features, **kwargs)
+ return out
\ No newline at end of file
diff --git a/training/mono/model/model_pipelines/model_pipeline.py b/training/mono/model/model_pipelines/model_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e294d9505db1bf0134adb106f239d0df2a59c76
--- /dev/null
+++ b/training/mono/model/model_pipelines/model_pipeline.py
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+from mono.utils.comm import get_func
+
+
+class EncoderDecoder(nn.Module):
+ def __init__(self, cfg):
+ super(EncoderDecoder, self).__init__()
+
+ self.encoder = get_func('mono.model.' + cfg.model.backbone.prefix + cfg.model.backbone.type)(**cfg.model.backbone)
+ self.decoder = get_func('mono.model.' + cfg.model.decode_head.prefix + cfg.model.decode_head.type)(cfg)
+
+ self.depth_out_head = DepthOutHead(method=cfg.model.depth_out_head.method, **cfg)
+ self.training = True
+
+ def forward(self, input, **kwargs):
+ # [f_32, f_16, f_8, f_4]
+ features = self.encoder(input)
+ # [x_32, x_16, x_8, x_4, x, ...]
+ decode_list = self.decoder(features)
+
+ pred, conf, logit, bins_edges = self.depth_out_head([decode_list[4], ])
+
+ auxi_preds = None
+ auxi_logits = None
+ out = dict(
+ prediction=pred[0],
+ confidence=conf[0],
+ pred_logit=logit[0],
+ auxi_pred=auxi_preds,
+ auxi_logit_list=auxi_logits,
+ bins_edges=bins_edges[0],
+ )
+ return out
\ No newline at end of file
diff --git a/training/mono/model/monodepth_model.py b/training/mono/model/monodepth_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ffa0acb5d20fad5df78b33c818f0c30e290b235
--- /dev/null
+++ b/training/mono/model/monodepth_model.py
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+from mono.utils.comm import get_func
+from .__base_model__ import BaseDepthModel
+
+class DepthModel(BaseDepthModel):
+    def __init__(self, cfg, criterions, **kwargs):
+ super(DepthModel, self).__init__(cfg, criterions)
+ model_type = cfg.model.type
+ self.training = True
+
+ # def inference(self, data):
+ # with torch.no_grad():
+ # pred_depth, _, confidence = self.inference(data)
+ # return pred_depth, confidence
+
+
+def get_monodepth_model(
+ cfg : dict,
+ criterions: dict,
+ **kwargs
+ ) -> nn.Module:
+ # config depth model
+ model = DepthModel(cfg, criterions, **kwargs)
+ #model.init_weights(load_imagenet_model, imagenet_ckpt_fpath)
+ assert isinstance(model, nn.Module)
+ return model
+
+
+def get_configured_monodepth_model(
+ cfg: dict,
+ criterions: dict,
+ ) -> nn.Module:
+ """
+ Args:
+ @ configs: configures for the network.
+ @ load_imagenet_model: whether to initialize from ImageNet-pretrained model.
+ @ imagenet_ckpt_fpath: string representing path to file with weights to initialize model with.
+ Returns:
+ # model: depth model.
+ """
+ model = get_monodepth_model(cfg, criterions)
+ return model
+
+
diff --git a/training/mono/scripts/test_scripts/test_vit.sh b/training/mono/scripts/test_scripts/test_vit.sh
new file mode 100644
index 0000000000000000000000000000000000000000..df29a45aa90b02d45be580030ab739cf11611381
--- /dev/null
+++ b/training/mono/scripts/test_scripts/test_vit.sh
@@ -0,0 +1,5 @@
+cd ../../../
+
+python mono/tools/test.py \
+ mono/configs/test_configs_vit_small/ibims.vit.dpt.raft.py \
+ --load-from vit_small_step00800000.pth
diff --git a/training/mono/scripts/train_scripts/train.sh b/training/mono/scripts/train_scripts/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..35d7552397944839584a0ca6928a12b8306c63f3
--- /dev/null
+++ b/training/mono/scripts/train_scripts/train.sh
@@ -0,0 +1,7 @@
+cd ../../../
+
+python mono/tools/train.py \
+ mono/configs/RAFTDecoder/vit.raft5.small.sanity_check.py \
+ --use-tensorboard \
+ --launcher slurm \
+ --experiment_name set1
diff --git a/training/mono/tools/test.py b/training/mono/tools/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..688a2b3d00c33ebdeefd5e4ef4530c2a3d46bd2b
--- /dev/null
+++ b/training/mono/tools/test.py
@@ -0,0 +1,165 @@
+import os
+import os.path as osp
+import time
+import sys
+CODE_SPACE=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(CODE_SPACE)
+#os.chdir(CODE_SPACE)
+import argparse
+import mmcv
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+try:
+ from mmcv.utils import Config, DictAction
+except:
+ from mmengine import Config, DictAction
+from datetime import timedelta
+import random
+import numpy as np
+
+from mono.datasets.distributed_sampler import log_canonical_transfer_info
+from mono.utils.comm import init_env
+from mono.utils.logger import setup_logger
+from mono.utils.db import load_data_info, reset_ckpt_path
+from mono.model.monodepth_model import get_configured_monodepth_model
+from mono.datasets.distributed_sampler import build_dataset_n_sampler_with_cfg
+from mono.utils.running import load_ckpt
+from mono.utils.do_test import do_test_with_dataloader, do_test_check_data
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Test a monocular depth model')
+ parser.add_argument('config', help='train config file path')
+ parser.add_argument('--show-dir', help='the dir to save logs and visualization results')
+ parser.add_argument(
+ '--load-from', help='the checkpoint file to load weights from')
+ parser.add_argument('--node_rank', type=int, default=0)
+ parser.add_argument('--nnodes',
+ type=int,
+ default=1,
+ help='number of nodes')
+ parser.add_argument(
+ '--options', nargs='+', action=DictAction, help='custom options')
+ parser.add_argument(
+ '--launcher', choices=['None', 'pytorch', 'slurm'], default='slurm',
+ help='job launcher')
+ args = parser.parse_args()
+ return args
+
+
+def main(args):
+ os.chdir(CODE_SPACE)
+ cfg = Config.fromfile(args.config)
+ cfg.dist_params.nnodes = args.nnodes
+ cfg.dist_params.node_rank = args.node_rank
+
+ if args.options is not None:
+ cfg.merge_from_dict(args.options)
+ # set cudnn_benchmark
+ #if cfg.get('cudnn_benchmark', False) and args.launcher != 'ror':
+ # torch.backends.cudnn.benchmark = True
+
+ # show_dir is determined in this priority: CLI > segment in file > filename
+ if args.show_dir is not None:
+ # update configs according to CLI args if args.show_dir is not None
+ cfg.show_dir = args.show_dir
+ elif cfg.get('show_dir', None) is None:
+ # use config filename + timestamp as default show_dir if cfg.show_dir is None
+ cfg.show_dir = osp.join('./show_dirs',
+ osp.splitext(osp.basename(args.config))[0],
+ args.timestamp)
+
+ # ckpt path
+ if args.load_from is None:
+ raise RuntimeError('Please set model path!')
+ cfg.load_from = args.load_from
+
+ # create show dir
+ os.makedirs(osp.abspath(cfg.show_dir), exist_ok=True)
+
+ # init the logger before other steps
+ cfg.log_file = osp.join(cfg.show_dir, f'{args.timestamp}.log')
+ logger = setup_logger(cfg.log_file)
+
+ # log some basic info
+ logger.info(f'Config:\n{cfg.pretty_text}')
+
+ # load db_info for data
+ # load data info
+ data_info = {}
+ load_data_info('data_server_info', data_info=data_info)
+ cfg.db_info = data_info
+ # update check point info
+ reset_ckpt_path(cfg.model, data_info)
+
+ # log data transfer to canonical space info
+ # log_canonical_transfer_info(cfg)
+
+ # init distributed env first, since logger depends on the dist info.
+    if args.launcher == 'None':
+ cfg.distributed = False
+ else:
+ cfg.distributed = True
+ init_env(args.launcher, cfg)
+ logger.info(f'Distributed training: {cfg.distributed}')
+
+ # dump config
+ cfg.dump(osp.join(cfg.show_dir, osp.basename(args.config)))
+
+ if not cfg.distributed:
+ main_worker(0, cfg, args.launcher)
+ else:
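+        # mp.spawn launches one process per local GPU and passes the process
+        # index (the local rank) as the first argument of main_worker.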
+ mp.spawn(main_worker, nprocs=cfg.dist_params.num_gpus_per_node, args=(cfg, args.launcher))
+
+def main_worker(local_rank: int, cfg: dict, launcher: str):
+ if cfg.distributed:
+ cfg.dist_params.global_rank = cfg.dist_params.node_rank * cfg.dist_params.num_gpus_per_node + local_rank
+ cfg.dist_params.local_rank = local_rank
+
+ torch.cuda.set_device(local_rank)
+ default_timeout = timedelta(minutes=30)
+ dist.init_process_group(backend=cfg.dist_params.backend,
+ init_method=cfg.dist_params.dist_url,
+ world_size=cfg.dist_params.world_size,
+ rank=cfg.dist_params.global_rank,
+ timeout=default_timeout,)
+
+ logger = setup_logger(cfg.log_file)
+ # build model
+ model = get_configured_monodepth_model(cfg,
+ None,
+ )
+
+ # build datasets
+ test_dataset, test_sampler = build_dataset_n_sampler_with_cfg(cfg, 'test')
+ # build data loaders
+ test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset,
+ batch_size=1,
+ num_workers=1,
+ sampler=test_sampler,
+ drop_last=False)
+
+
+ # config distributed training
+ if cfg.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
+ device_ids=[local_rank],
+ output_device=local_rank,
+ find_unused_parameters=True)
+ else:
+ model = torch.nn.DataParallel(model.cuda())
+
+ # load ckpt
+    model, _, _, _ = load_ckpt(cfg.load_from, model, strict_match=False)
+ model.eval()
+ do_test_with_dataloader(model, cfg, test_dataloader, logger=logger, is_distributed=cfg.distributed)
+ # do_test_check_data(model, cfg, test_dataloader, logger=logger, is_distributed=cfg.distributed, local_rank=local_rank)
+
+
+if __name__=='__main__':
+ # load args
+ args = parse_args()
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ args.timestamp = timestamp
+ main(args)
diff --git a/training/mono/tools/train.py b/training/mono/tools/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..598bb79567e632dbea5dae593a8d5c74ad66668e
--- /dev/null
+++ b/training/mono/tools/train.py
@@ -0,0 +1,254 @@
+import os
+import os.path as osp
+import time
+import sys
+CODE_SPACE=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+sys.path.append(CODE_SPACE)
+#os.chdir(CODE_SPACE)
+import argparse
+import copy
+import mmcv
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+try:
+ from mmcv.utils import Config, DictAction
+except:
+ from mmengine import Config, DictAction
+import socket
+import subprocess
+from datetime import timedelta
+import random
+import numpy as np
+import logging
+
+from mono.datasets.distributed_sampler import log_canonical_transfer_info
+from mono.utils.comm import init_env, collect_env
+from mono.utils.logger import setup_logger
+from mono.utils.db import load_data_info, reset_ckpt_path
+from mono.utils.do_train import do_train
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Train a monocular depth model')
+ parser.add_argument('config', help='train config file path')
+ parser.add_argument('--work-dir', help='the dir to save logs and models')
+ parser.add_argument('--tensorboard-dir', help='the dir to save tensorboard logs')
+ parser.add_argument(
+ '--load-from', help='the checkpoint file to load weights from')
+ parser.add_argument(
+ '--resume-from', help='the checkpoint file to resume from')
+ parser.add_argument(
+ '--no-validate',
+ action='store_true',
+ help='whether not to evaluate the checkpoint during training')
+ parser.add_argument(
+ '--gpu-ids',
+ type=int,
+ nargs='+',
+ help='ids of gpus to use '
+ '(only applicable to non-distributed training)')
+ parser.add_argument('--seed', type=int, default=88, help='random seed')
+ parser.add_argument(
+ '--deterministic',
+ action='store_true',
+ help='whether to set deterministic options for CUDNN backend.')
+ parser.add_argument(
+ '--use-tensorboard',
+ action='store_true',
+        help='whether to log training progress to TensorBoard.')
+ parser.add_argument(
+ '--options', nargs='+', action=DictAction, help='custom options')
+ parser.add_argument('--node_rank', type=int, default=0)
+ parser.add_argument('--nnodes',
+ type=int,
+ default=1,
+ help='number of nodes')
+ parser.add_argument(
+ '--launcher', choices=['None', 'pytorch', 'slurm', 'mpi', 'ror'], default='slurm',
+ help='job launcher')
+ parser.add_argument('--local_rank',
+ type=int,
+ default=0,
+ help='rank')
+ parser.add_argument('--experiment_name', default='debug', help='the experiment name for mlflow')
+ args = parser.parse_args()
+ return args
+
+
+def set_random_seed(seed, deterministic=False):
+ """Set random seed.
+ Args:
+ @seed (int): Seed to be used.
+ @deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ #if deterministic:
+ # torch.backends.cudnn.deterministic = True
+ # torch.backends.cudnn.benchmark = False
+
+def main(args):
+ os.chdir(CODE_SPACE)
+ cfg = Config.fromfile(args.config)
+ cfg.dist_params.nnodes = args.nnodes
+ cfg.dist_params.node_rank = args.node_rank
+ cfg.deterministic = args.deterministic
+ if args.options is not None:
+ cfg.merge_from_dict(args.options)
+ # set cudnn_benchmark
+ #if cfg.get('cudnn_benchmark', False) and args.launcher != 'ror':
+ # torch.backends.cudnn.benchmark = True
+ # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+ # in PyTorch 1.12 and later.
+ # torch.backends.cuda.matmul.allow_tf32 = False
+ # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+ # torch.backends.cudnn.allow_tf32 = False
+
+ # work_dir is determined in this priority: CLI > segment in file > filename
+ if args.work_dir is not None:
+ # update configs according to CLI args if args.work_dir is not None
+ cfg.work_dir = args.work_dir
+ elif cfg.get('work_dir', None) is None:
+ # use config filename + timestamp as default work_dir if cfg.work_dir is None
+ cfg.work_dir = osp.join('./work_dirs',
+ osp.splitext(osp.basename(args.config))[0],
+ args.timestamp)
+ # tensorboard_dir is determined in this priority: CLI > segment in file > filename
+ if args.tensorboard_dir is not None:
+ cfg.tensorboard_dir = args.tensorboard_dir
+ elif cfg.get('tensorboard_dir', None) is None:
+ # use cfg.work_dir + 'tensorboard' as default tensorboard_dir if cfg.tensorboard_dir is None
+ cfg.tensorboard_dir = osp.join(cfg.work_dir, 'tensorboard')
+
+ # ckpt path
+ if args.load_from is not None:
+ cfg.load_from = args.load_from
+ # resume training
+ if args.resume_from is not None:
+ cfg.resume_from = args.resume_from
+
+ # create work_dir and tensorboard_dir
+ os.makedirs(osp.abspath(cfg.work_dir), exist_ok=True)
+ os.makedirs(os.path.abspath(cfg.tensorboard_dir), exist_ok=True)
+
+ # init the logger before other steps
+ cfg.log_file = osp.join(cfg.work_dir, f'{args.timestamp}.log')
+ logger = setup_logger(cfg.log_file)
+
+ # init the meta dict to record some important information such as
+ # environment info and seed, which will be logged
+ meta = dict()
+ # log env info
+ env_info_dict = collect_env()
+ env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
+ dash_line = '-' * 60 + '\n'
+ logger.info('Environment info:\n' + dash_line + env_info + '\n' +
+ dash_line)
+ meta['env_info'] = env_info
+
+ # log some basic info
+ # logger.info(f'Config:\n{cfg.pretty_text}')
+
+ # mute online evaluation
+ if args.no_validate:
+ cfg.evaluation.online_eval = False
+
+
+ cfg.seed = args.seed
+ meta['seed'] = args.seed
+ meta['exp_name'] = osp.basename(args.config)
+
+ # load data info
+ data_info = {}
+ load_data_info('data_server_info', data_info=data_info)
+ cfg.db_info = data_info
+ # update check point info
+ reset_ckpt_path(cfg.model, data_info)
+
+    # log data transfer to canonical space info
+ # log_canonical_transfer_info(cfg)
+
+ # init distributed env first, since logger depends on the dist info.
+ if args.launcher == 'None':
+ cfg.distributed = False
+ else:
+ cfg.distributed = True
+ init_env(args.launcher, cfg)
+ logger.info(f'Distributed training: {cfg.distributed}')
+ logger.info(cfg.dist_params)
+ # dump config
+ cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
+
+ cfg.experiment_name = args.experiment_name
+
+ if not cfg.distributed:
+ main_worker(0, cfg)
+ else:
+ # distributed training
+ if args.launcher == 'slurm':
+ mp.spawn(main_worker, nprocs=cfg.dist_params.num_gpus_per_node, args=(cfg, args.launcher))
+ elif args.launcher == 'pytorch':
+ main_worker(args.local_rank, cfg, args.launcher)
+
+def main_worker(local_rank: int, cfg: dict, launcher: str='slurm'):
+ logger = setup_logger(cfg.log_file)
+ if cfg.distributed:
+ if launcher == 'slurm':
+ torch.set_num_threads(8) # without it, the spawn method is much slower than the launch method
+ cfg.dist_params.global_rank = cfg.dist_params.node_rank * cfg.dist_params.num_gpus_per_node + local_rank
+ cfg.dist_params.local_rank = local_rank
+ os.environ['RANK']=str(cfg.dist_params.global_rank)
+ else:
+ torch.set_num_threads(1)
+
+ torch.cuda.set_device(local_rank)
+ default_timeout = timedelta(minutes=10)
+ dist.init_process_group(
+ backend=cfg.dist_params.backend,
+ init_method=cfg.dist_params.dist_url,
+ world_size=cfg.dist_params.world_size,
+ rank=cfg.dist_params.global_rank,)
+ #timeout=default_timeout,)
+ dist.barrier()
+
+ # if cfg.distributed:
+
+ # cfg.dist_params.global_rank = cfg.dist_params.node_rank * cfg.dist_params.num_gpus_per_node + local_rank
+ # cfg.dist_params.local_rank = local_rank
+ # os.environ['RANK']=str(cfg.dist_params.global_rank)
+
+ # if launcher == 'ror':
+ # init_torch_process_group(use_hvd=False)
+ # else:
+ # #torch.set_num_threads(4) # without it, the spawn method maybe much slower than the launch method
+ # torch.cuda.set_device(local_rank)
+ # default_timeout = timedelta(minutes=30)
+ # dist.init_process_group(
+ # backend=cfg.dist_params.backend,
+ # init_method=cfg.dist_params.dist_url,
+ # world_size=cfg.dist_params.world_size,
+ # rank=cfg.dist_params.global_rank,)
+ # #timeout=default_timeout,)
+
+ # set random seeds
+ if cfg.seed is not None:
+ logger.info(f'Set random seed to {cfg.seed}, deterministic: 'f'{cfg.deterministic}')
+ set_random_seed(cfg.seed, deterministic=cfg.deterministic)
+ # with torch.autograd.set_detect_anomaly(True):
+ do_train(local_rank, cfg)
+
+
+if __name__=='__main__':
+ # load args
+ args = parse_args()
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ args.timestamp = timestamp
+ print(args.work_dir, args.tensorboard_dir)
+ main(args)
\ No newline at end of file
diff --git a/training/mono/utils/__init__.py b/training/mono/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/training/mono/utils/__init__.py
@@ -0,0 +1 @@
+
diff --git a/training/mono/utils/avg_meter.py b/training/mono/utils/avg_meter.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7321dd4e222dd0d02d84cc5aa31bdeab24007be
--- /dev/null
+++ b/training/mono/utils/avg_meter.py
@@ -0,0 +1,561 @@
+import numpy as np
+import torch
+import torch.distributed as dist
+from .inverse_warp import pixel2cam, cam2pixel2
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+ def __init__(self) -> None:
+ self.reset()
+
+ def reset(self) -> None:
+ self.val = np.longdouble(0.0)
+ self.avg = np.longdouble(0.0)
+ self.sum = np.longdouble(0.0)
+ self.count = np.longdouble(0.0)
+
+ def update(self, val, n: float = 1) -> None:
+ self.val = val
+ self.sum += val
+ self.count += n
+ self.avg = self.sum / (self.count + 1e-6)
+
+class MetricAverageMeter(AverageMeter):
+ """
+    An AverageMeter designed specifically for evaluating depth (and surface normal) estimation results.
+ """
+ def __init__(self, metrics: list) -> None:
+ """ Initialize object. """
+ # average meters for metrics
+ self.abs_rel = AverageMeter()
+ self.rmse = AverageMeter()
+ self.silog = AverageMeter()
+ self.delta1 = AverageMeter()
+ self.delta2 = AverageMeter()
+ self.delta3 = AverageMeter()
+
+ self.metrics = metrics
+
+ self.consistency = AverageMeter()
+ self.log10 = AverageMeter()
+ self.rmse_log = AverageMeter()
+ self.sq_rel = AverageMeter()
+
+ # normal
+ self.normal_mean = AverageMeter()
+ self.normal_rmse = AverageMeter()
+ self.normal_a1 = AverageMeter()
+ self.normal_a2 = AverageMeter()
+
+ self.normal_median = AverageMeter()
+ self.normal_a3 = AverageMeter()
+ self.normal_a4 = AverageMeter()
+ self.normal_a5 = AverageMeter()
+
+
+ def update_metrics_cpu(self,
+ pred: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor,):
+ """
+ Update metrics on cpu
+ """
+
+ assert pred.shape == target.shape
+
+ if len(pred.shape) == 3:
+ pred = pred[:, None, :, :]
+ target = target[:, None, :, :]
+ mask = mask[:, None, :, :]
+ elif len(pred.shape) == 2:
+ pred = pred[None, None, :, :]
+ target = target[None, None, :, :]
+ mask = mask[None, None, :, :]
+
+
+ # Absolute relative error
+ abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask)
+ abs_rel_sum = abs_rel_sum.numpy()
+ valid_pics = valid_pics.numpy()
+ self.abs_rel.update(abs_rel_sum, valid_pics)
+
+ # squared relative error
+ sqrel_sum, _ = get_sqrel_err(pred, target, mask)
+ sqrel_sum = sqrel_sum.numpy()
+ self.sq_rel.update(sqrel_sum, valid_pics)
+
+ # root mean squared error
+ rmse_sum, _ = get_rmse_err(pred, target, mask)
+ rmse_sum = rmse_sum.numpy()
+ self.rmse.update(rmse_sum, valid_pics)
+
+ # log root mean squared error
+ log_rmse_sum, _ = get_rmse_log_err(pred, target, mask)
+ log_rmse_sum = log_rmse_sum.numpy()
+        self.rmse_log.update(log_rmse_sum, valid_pics)
+
+ # log10 error
+ log10_sum, _ = get_log10_err(pred, target, mask)
+ log10_sum = log10_sum.numpy()
+        self.log10.update(log10_sum, valid_pics)
+
+ # scale-invariant root mean squared error in log space
+ silog_sum, _ = get_silog_err(pred, target, mask)
+ silog_sum = silog_sum.numpy()
+ self.silog.update(silog_sum, valid_pics)
+
+ # ratio error, delta1, ....
+ delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_error(pred, target, mask)
+ delta1_sum = delta1_sum.numpy()
+ delta2_sum = delta2_sum.numpy()
+ delta3_sum = delta3_sum.numpy()
+
+ self.delta1.update(delta1_sum, valid_pics)
+        self.delta2.update(delta2_sum, valid_pics)
+        self.delta3.update(delta3_sum, valid_pics)
+
+
+ def update_metrics_gpu(
+ self,
+ pred: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor,
+ is_distributed: bool,
+ pred_next: torch.tensor = None,
+ pose_f1_to_f2: torch.tensor = None,
+ intrinsic: torch.tensor = None):
+ """
+ Update metric on GPU. It supports distributed processing. If multiple machines are employed, please
+ set 'is_distributed' as True.
+ """
+ assert pred.shape == target.shape
+
+ if len(pred.shape) == 3:
+ pred = pred[:, None, :, :]
+ target = target[:, None, :, :]
+ mask = mask[:, None, :, :]
+ elif len(pred.shape) == 2:
+ pred = pred[None, None, :, :]
+ target = target[None, None, :, :]
+ mask = mask[None, None, :, :]
+
+
+ # Absolute relative error
+ abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(abs_rel_sum), dist.all_reduce(valid_pics)
+ abs_rel_sum = abs_rel_sum.cpu().numpy()
+ valid_pics = int(valid_pics)
+ self.abs_rel.update(abs_rel_sum, valid_pics)
+
+ # root mean squared error
+ rmse_sum, _ = get_rmse_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(rmse_sum)
+ rmse_sum = rmse_sum.cpu().numpy()
+ self.rmse.update(rmse_sum, valid_pics)
+
+ # log root mean squared error
+ log_rmse_sum, _ = get_rmse_log_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(log_rmse_sum)
+ log_rmse_sum = log_rmse_sum.cpu().numpy()
+ self.rmse_log.update(log_rmse_sum, valid_pics)
+
+ # log10 error
+ log10_sum, _ = get_log10_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(log10_sum)
+ log10_sum = log10_sum.cpu().numpy()
+ self.log10.update(log10_sum, valid_pics)
+
+ # scale-invariant root mean squared error in log space
+ silog_sum, _ = get_silog_err(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(silog_sum)
+ silog_sum = silog_sum.cpu().numpy()
+ self.silog.update(silog_sum, valid_pics)
+
+ # ratio error, delta1, ....
+ delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_error(pred, target, mask)
+ if is_distributed:
+ dist.all_reduce(delta1_sum), dist.all_reduce(delta2_sum), dist.all_reduce(delta3_sum)
+ delta1_sum = delta1_sum.cpu().numpy()
+ delta2_sum = delta2_sum.cpu().numpy()
+ delta3_sum = delta3_sum.cpu().numpy()
+
+ self.delta1.update(delta1_sum, valid_pics)
+ self.delta2.update(delta2_sum, valid_pics)
+ self.delta3.update(delta3_sum, valid_pics)
+
+ # video consistency error
+ consistency_rel_sum, valid_warps = get_video_consistency_err(pred, pred_next, pose_f1_to_f2, intrinsic)
+ if is_distributed:
+ dist.all_reduce(consistency_rel_sum), dist.all_reduce(valid_warps)
+ consistency_rel_sum = consistency_rel_sum.cpu().numpy()
+ valid_warps = int(valid_warps)
+ self.consistency.update(consistency_rel_sum, valid_warps)
+
+ ## for surface normal
+ def update_normal_metrics_gpu(
+ self,
+ pred: torch.Tensor, # (B, 3, H, W)
+ target: torch.Tensor, # (B, 3, H, W)
+ mask: torch.Tensor, # (B, 1, H, W)
+ is_distributed: bool,
+ ):
+ """
+ Update metric on GPU. It supports distributed processing. If multiple machines are employed, please
+ set 'is_distributed' as True.
+ """
+ assert pred.shape == target.shape
+
+ valid_pics = torch.sum(mask, dtype=torch.float32) + 1e-6
+
+ if valid_pics < 10:
+ return
+
+ mean_error = rmse_error = a1_error = a2_error = dist_node_cnt = valid_pics
+ normal_error = torch.cosine_similarity(pred, target, dim=1)
+ normal_error = torch.clamp(normal_error, min=-1.0, max=1.0)
+ angle_error = torch.acos(normal_error) * 180.0 / torch.pi
+ angle_error = angle_error[:, None, :, :]
+ angle_error = angle_error[mask]
+ # Calculation error
+ mean_error = angle_error.sum() / valid_pics
+ rmse_error = torch.sqrt( torch.sum(torch.square(angle_error)) / valid_pics )
+ median_error = angle_error.median()
+ a1_error = 100.0 * (torch.sum(angle_error < 5) / valid_pics)
+ a2_error = 100.0 * (torch.sum(angle_error < 7.5) / valid_pics)
+
+ a3_error = 100.0 * (torch.sum(angle_error < 11.25) / valid_pics)
+ a4_error = 100.0 * (torch.sum(angle_error < 22.5) / valid_pics)
+ a5_error = 100.0 * (torch.sum(angle_error < 30) / valid_pics)
+
+ # if valid_pics > 1e-5:
+ # If the current node gets data with valid normal
+ dist_node_cnt = (valid_pics - 1e-6) / valid_pics
+
+ if is_distributed:
+ dist.all_reduce(dist_node_cnt)
+ dist.all_reduce(mean_error)
+ dist.all_reduce(rmse_error)
+ dist.all_reduce(a1_error)
+ dist.all_reduce(a2_error)
+
+ dist.all_reduce(a3_error)
+ dist.all_reduce(a4_error)
+ dist.all_reduce(a5_error)
+
+ dist_node_cnt = dist_node_cnt.cpu().numpy()
+ self.normal_mean.update(mean_error.cpu().numpy(), dist_node_cnt)
+ self.normal_rmse.update(rmse_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a1.update(a1_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a2.update(a2_error.cpu().numpy(), dist_node_cnt)
+
+ self.normal_median.update(median_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a3.update(a3_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a4.update(a4_error.cpu().numpy(), dist_node_cnt)
+ self.normal_a5.update(a5_error.cpu().numpy(), dist_node_cnt)
+
+
+ def get_metrics(self,):
+ """
+ """
+ metrics_dict = {}
+ for metric in self.metrics:
+ metrics_dict[metric] = self.__getattribute__(metric).avg
+ return metrics_dict
+
+
+
+
+def get_absrel_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor):
+ """
+ Computes absolute relative error.
+ Takes preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ #Mean Absolute Relative Error
+ rel = torch.abs(t_m - p_m) / (t_m + 1e-10) # compute errors
+ abs_rel_sum = torch.sum(rel.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ abs_err = abs_rel_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
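+    # return the summed per-image mean AbsRel together with the number of images
+    # that contain valid pixels; the caller (AverageMeter) divides sum by count.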
+ return torch.sum(abs_err), valid_pics
+
+def get_sqrel_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor):
+ """
+ Computes squared relative error.
+ Takes preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ #Mean Absolute Relative Error
+ sq_rel = torch.abs(t_m - p_m)**2 / (t_m + 1e-10) # compute errors
+ sq_rel_sum = torch.sum(sq_rel.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ sqrel_err = sq_rel_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(sqrel_err), valid_pics
+
+def get_log10_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor):
+ """
+ Computes log10 error.
+ Takes preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask
+ log10_diff = torch.abs(diff_log) # compute errors
+ log10_sum = torch.sum(log10_diff.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ abs_err = log10_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(abs_err), valid_pics
+
+def get_rmse_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor):
+ """
+    Computes root mean squared error.
+ Takes preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ square = (t_m - p_m) ** 2
+ rmse_sum = torch.sum(square.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ rmse = torch.sqrt(rmse_sum / (num + 1e-10))
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(rmse), valid_pics
+
+def get_rmse_log_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor):
+ """
+    Computes root mean squared error in log space.
+ Takes preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ diff_log = (torch.log(p_m+1e-10) - torch.log(t_m+1e-10)) * mask
+ square = diff_log ** 2
+ rmse_sum = torch.sum(square.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ rmse = torch.sqrt(rmse_sum / (num + 1e-10))
+ valid_pics = torch.sum(num > 0)
+ return torch.sum(rmse), valid_pics
+
+
+def get_silog_err(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor):
+ """
+ Computes scale invariant loss based on differences of logs of depth maps.
+ Takes preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred * mask
+
+ diff_log = (torch.log(p_m+1e-10) - torch.log(t_m+1e-10)) * mask
+ diff_log_sum = torch.sum(diff_log.reshape((b, c, -1)), dim=2) # [b, c]
+ diff_log_square = diff_log ** 2
+ diff_log_square_sum = torch.sum(diff_log_square.reshape((b, c, -1)), dim=2) # [b, c]
+ num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c]
+ silog = torch.sqrt(diff_log_square_sum / (num + 1e-10) - (diff_log_sum / (num + 1e-10)) **2 )
+ valid_pics = torch.sum(num > 0)
+ if torch.isnan(torch.sum(silog)):
+        print('NaN in silog')
+ return torch.sum(silog), valid_pics
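+
+# The scale-invariant log error above follows
+#   silog = sqrt( mean(d^2) - mean(d)^2 ),  d = log(pred) - log(gt) over masked pixels,
+# and the function returns the per-image sum together with the number of valid images.
+# Minimal usage sketch (illustrative tensors only, assuming positive depths):
+#   pred = torch.rand(2, 1, 4, 4) + 0.1
+#   gt   = torch.rand(2, 1, 4, 4) + 0.1
+#   mask = torch.ones_like(gt, dtype=torch.bool)
+#   silog_sum, n_valid = get_silog_err(pred, gt, mask)
+#   silog_mean = silog_sum / n_valid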
+
+
+def get_ratio_error(pred: torch.tensor,
+ target: torch.tensor,
+ mask: torch.tensor):
+ """
+ Computes the percentage of pixels for which the ratio of the two depth maps is less than a given threshold.
+ Takes preprocessed depths (no nans, infs and non-positive values).
+ pred, target, and mask should be in the shape of [b, c, h, w]
+ """
+    assert len(pred.shape) == 4 and len(target.shape) == 4
+ b, c, h, w = pred.shape
+ mask = mask.to(torch.float)
+ t_m = target * mask
+ p_m = pred
+
+ gt_pred = t_m / (p_m + 1e-10)
+ pred_gt = p_m / (t_m + 1e-10)
+ gt_pred = gt_pred.reshape((b, c, -1))
+ pred_gt = pred_gt.reshape((b, c, -1))
+ gt_pred_gt = torch.cat((gt_pred, pred_gt), axis=1)
+ ratio_max = torch.amax(gt_pred_gt, axis=1)
+
+ mask = mask.reshape((b, -1))
+ delta_1_sum = torch.sum((ratio_max < 1.25) * mask, dim=1) # [b, ]
+ delta_2_sum = torch.sum((ratio_max < 1.25**2) * mask, dim=1) # [b,]
+ delta_3_sum = torch.sum((ratio_max < 1.25**3) * mask, dim=1) # [b, ]
+ num = torch.sum(mask, dim=1) # [b, ]
+
+ delta_1 = delta_1_sum / (num + 1e-10)
+ delta_2 = delta_2_sum / (num + 1e-10)
+ delta_3 = delta_3_sum / (num + 1e-10)
+ valid_pics = torch.sum(num > 0)
+
+ return torch.sum(delta_1), torch.sum(delta_2), torch.sum(delta_3), valid_pics
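+
+# delta_k above is the fraction of masked pixels with
+#   max(gt / pred, pred / gt) < 1.25 ** k,  k in {1, 2, 3};
+# per-image sums are returned, so the batch average is sum / valid_pics.
+# Usage sketch (illustrative; pred, gt, and mask shaped [b, c, h, w]):
+#   d1_sum, d2_sum, d3_sum, n_valid = get_ratio_error(pred, gt, mask)
+#   delta1 = d1_sum / n_valid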
+
+def unproj_pcd(
+ depth: torch.tensor,
+ intrinsic: torch.tensor
+ ):
+ depth = depth.squeeze(1) # [B, H, W]
+ b, h, w = depth.size()
+ v = torch.arange(0, h).view(1, h, 1).expand(b, h, w).type_as(depth) # [B, H, W]
+ u = torch.arange(0, w).view(1, 1, w).expand(b, h, w).type_as(depth) # [B, H, W]
+    x = (u - intrinsic[:, 0, 2].view(-1, 1, 1)) / intrinsic[:, 0, 0].view(-1, 1, 1) * depth # [B, H, W]
+    y = (v - intrinsic[:, 1, 2].view(-1, 1, 1)) / intrinsic[:, 1, 1].view(-1, 1, 1) * depth # [B, H, W]
+ pcd = torch.stack([x, y, depth], dim=1)
+ return pcd
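+
+# unproj_pcd applies the pinhole back-projection
+#   x = (u - cx) / fx * z,   y = (v - cy) / fy * z,
+# giving a point cloud of shape [B, 3, H, W]. Usage sketch (illustrative intrinsics):
+#   depth = torch.ones(1, 1, 480, 640)
+#   K = torch.tensor([[[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]]])
+#   pcd = unproj_pcd(depth, K)  # -> [1, 3, 480, 640]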
+
+def forward_warp(
+ depth: torch.tensor,
+ intrinsic: torch.tensor,
+ pose: torch.tensor,
+ ):
+ """
+ Warp the depth with the provided pose.
+ Args:
+ depth: depth map of the target image -- [B, 1, H, W]
+ intrinsic: camera intrinsic parameters -- [B, 3, 3]
+ pose: the camera pose -- [B, 4, 4]
+ """
+ B, _, H, W = depth.shape
+ pcd = unproj_pcd(depth.float(), intrinsic.float())
+ pcd = pcd.reshape(B, 3, -1) # [B, 3, H*W]
+ rot, tr = pose[:, :3, :3], pose[:, :3, -1:]
+ proj_pcd = rot @ pcd + tr
+
+ img_coors = intrinsic @ proj_pcd
+
+ X = img_coors[:, 0, :]
+ Y = img_coors[:, 1, :]
+ Z = img_coors[:, 2, :].clamp(min=1e-3)
+
+ x_img_coor = (X/Z + 0.5).long()
+ y_img_coor = (Y/Z + 0.5).long()
+
+ X_mask = ((x_img_coor >=0) & (x_img_coor < W))
+ Y_mask = ((y_img_coor >=0) & (y_img_coor < H))
+ mask = X_mask & Y_mask
+
+ proj_depth = torch.zeros_like(Z).reshape(B, 1, H, W)
+ for i in range(B):
+ proj_depth[i, :, y_img_coor[i,...][mask[i,...]], x_img_coor[i,...][mask[i,...]]] = Z[i,...][mask[i,...]]
+ plt.imsave('warp2.png', proj_depth.squeeze().cpu().numpy(), cmap='rainbow')
+ return proj_depth
+
+
+def get_video_consistency_err(
+ pred_f1: torch.tensor,
+ pred_f2: torch.tensor,
+ ego_pose_f1_to_f2: torch.tensor,
+ intrinsic: torch.tensor,
+ ):
+ """
+ Compute consistency error between consecutive frames.
+ """
+ if pred_f2 is None or ego_pose_f1_to_f2 is None or intrinsic is None:
+ return torch.zeros_like(pred_f1).sum(), torch.zeros_like(pred_f1).sum()
+ ego_pose_f1_to_f2 = ego_pose_f1_to_f2.float()
+ pred_f2 = pred_f2.float()
+
+ pred_f1 = pred_f1[:, None, :, :] if pred_f1.ndim == 3 else pred_f1
+ pred_f2 = pred_f2[:, None, :, :] if pred_f2.ndim == 3 else pred_f2
+ pred_f1 = pred_f1[None, None, :, :] if pred_f1.ndim == 2 else pred_f1
+ pred_f2 = pred_f2[None, None, :, :] if pred_f2.ndim == 2 else pred_f2
+
+ B, _, H, W = pred_f1.shape
+ # Get projection matrix for tgt camera frame to source pixel frame
+ cam_coords = pixel2cam(pred_f1.squeeze(1).float(), intrinsic.inverse().float()) # [B,3,H,W]
+ #proj_depth_my = forward_warp(pred_f1, intrinsic, ego_pose_f1_to_f2)
+
+ proj_f1_to_f2 = intrinsic @ ego_pose_f1_to_f2[:, :3, :] # [B, 3, 4]
+ rot, tr = proj_f1_to_f2[:, :, :3], proj_f1_to_f2[:, :, -1:]
+ f2_pixel_coords, warped_depth_f1_to_f2 = cam2pixel2(cam_coords, rot, tr, padding_mode="zeros") # [B,H,W,2]
+
+ projected_depth = F.grid_sample(pred_f2, f2_pixel_coords, padding_mode="zeros", align_corners=False)
+
+ mask_valid = (projected_depth > 1e-6) & (warped_depth_f1_to_f2 > 1e-6)
+
+ # plt.imsave('f1.png', pred_f1.squeeze().cpu().numpy(), cmap='rainbow')
+ # plt.imsave('f2.png', pred_f2.squeeze().cpu().numpy(), cmap='rainbow')
+ # plt.imsave('warp.png', warped_depth_f1_to_f2.squeeze().cpu().numpy(), cmap='rainbow')
+ # plt.imsave('proj.png', projected_depth.squeeze().cpu().numpy(), cmap='rainbow')
+
+ consistency_rel_err, valid_pix = get_absrel_err(warped_depth_f1_to_f2, projected_depth, mask_valid)
+ return consistency_rel_err, valid_pix
+
+
+if __name__ == '__main__':
+ cfg = ['abs_rel', 'delta1']
+ dam = MetricAverageMeter(cfg)
+
+ pred_depth = np.random.random([2, 480, 640])
+ gt_depth = np.random.random([2, 480, 640]) - 0.5 #np.ones_like(pred_depth) * (-1) #
+ intrinsic = [[100, 100, 200, 200], [200, 200, 300, 300]]
+
+ pred = torch.from_numpy(pred_depth).cuda()
+ gt = torch.from_numpy(gt_depth).cuda()
+
+ mask = gt > 0
+ dam.update_metrics_gpu(pred, pred, mask, False)
+ eval_error = dam.get_metrics()
+ print(eval_error)
diff --git a/training/mono/utils/comm.py b/training/mono/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..11227f5c569c0839f9c1239a046b389a29272a65
--- /dev/null
+++ b/training/mono/utils/comm.py
@@ -0,0 +1,343 @@
+import importlib
+import torch
+import torch.distributed as dist
+from .avg_meter import AverageMeter
+from collections import defaultdict, OrderedDict
+import os
+import socket
+from mmcv.utils import collect_env as collect_base_env
+try:
+ from mmcv.utils import get_git_hash
+except ImportError:
+ from mmengine import get_git_hash
+#import mono.mmseg as mmseg
+import mmseg
+import time
+import datetime
+import logging
+
+
+def main_process() -> bool:
+ return get_rank() == 0
+ #return not cfg.distributed or \
+ # (cfg.distributed and cfg.local_rank == 0)
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+def _find_free_port():
+ # refer to https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # Binding to port 0 will cause the OS to find an available port for us
+ sock.bind(('', 0))
+ port = sock.getsockname()[1]
+ sock.close()
+ # NOTE: there is still a chance the port could be taken by other processes.
+ return port
+
+def _is_free_port(port):
+ ips = socket.gethostbyname_ex(socket.gethostname())[-1]
+ ips.append('localhost')
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ return all(s.connect_ex((ip, port)) != 0 for ip in ips)
+
+
+def collect_env():
+ """Collect the information of the running environments."""
+ env_info = collect_base_env()
+ env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
+
+ return env_info
+
+def init_env(launcher, cfg):
+ """Initialize distributed training environment.
+    If ``cfg.dist_params.dist_url`` is specified as 'env://', the master port is taken from the
+    ``MASTER_PORT`` environment variable. If ``MASTER_PORT`` is not set, port ``16500``
+    is used when free, otherwise a random free port is picked.
+ """
+ if launcher == 'slurm':
+ _init_dist_slurm(cfg)
+ elif launcher == 'ror':
+ _init_dist_ror(cfg)
+ elif launcher == 'None':
+ _init_none_dist(cfg)
+ elif launcher == 'pytorch':
+ _init_dist_pytorch(cfg)
+ else:
+        raise RuntimeError(f'launcher {launcher} is not supported!')
+
+def _init_none_dist(cfg):
+ cfg.dist_params.num_gpus_per_node = 1
+ cfg.dist_params.world_size = 1
+ cfg.dist_params.nnodes = 1
+ cfg.dist_params.node_rank = 0
+ cfg.dist_params.global_rank = 0
+ cfg.dist_params.local_rank = 0
+ os.environ["WORLD_SIZE"] = str(1)
+
+def _init_dist_ror(cfg):
+ from ac2.ror.comm import get_local_rank, get_world_rank, get_local_size, get_node_rank, get_world_size
+ cfg.dist_params.num_gpus_per_node = get_local_size()
+ cfg.dist_params.world_size = get_world_size()
+ cfg.dist_params.nnodes = (get_world_size()) // (get_local_size())
+ cfg.dist_params.node_rank = get_node_rank()
+ cfg.dist_params.global_rank = get_world_rank()
+ cfg.dist_params.local_rank = get_local_rank()
+ os.environ["WORLD_SIZE"] = str(get_world_size())
+
+
+def _init_dist_pytorch(cfg):
+ # load env. paras.
+ local_rank = int(os.environ['LOCAL_RANK'])
+ world_size = int(os.environ['WORLD_SIZE'])
+ global_rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+
+ cfg.dist_params.num_gpus_per_node = num_gpus
+ cfg.dist_params.world_size = world_size
+ cfg.dist_params.nnodes = int(world_size // num_gpus)
+    cfg.dist_params.node_rank = int(global_rank // num_gpus)
+ cfg.dist_params.global_rank = global_rank
+
+ os.environ['NODE_RANK'] = str(cfg.dist_params.node_rank)
+ # set dist_url to 'env://'
+ cfg.dist_params.dist_url = 'env://' #f"{master_addr}:{master_port}"
+
+
+def _init_dist_slurm(cfg):
+ if 'NNODES' not in os.environ:
+ os.environ['NNODES'] = str(cfg.dist_params.nnodes)
+ if 'NODE_RANK' not in os.environ:
+ os.environ['NODE_RANK'] = str(cfg.dist_params.node_rank)
+
+ #cfg.dist_params.
+ num_gpus = torch.cuda.device_count()
+ world_size = int(os.environ['NNODES']) * num_gpus
+ os.environ['WORLD_SIZE'] = str(world_size)
+
+ # config port
+ if 'MASTER_PORT' in os.environ:
+ master_port = str(os.environ['MASTER_PORT']) # use MASTER_PORT in the environment variable
+ else:
+ # if torch.distributed default port(29500) is available
+ # then use it, else find a free port
+ if _is_free_port(16500):
+ master_port = '16500'
+ else:
+ master_port = str(_find_free_port())
+ os.environ['MASTER_PORT'] = master_port
+
+ # config addr
+ if 'MASTER_ADDR' in os.environ:
+        master_addr = str(os.environ['MASTER_ADDR']) # use MASTER_ADDR in the environment variable
+ # elif cfg.dist_params.dist_url is not None:
+ # master_addr = ':'.join(str(cfg.dist_params.dist_url).split(':')[:2])
+ else:
+ master_addr = '127.0.0.1' #'tcp://127.0.0.1'
+ os.environ['MASTER_ADDR'] = master_addr
+
+ # set dist_url to 'env://'
+ cfg.dist_params.dist_url = 'env://' #f"{master_addr}:{master_port}"
+
+ cfg.dist_params.num_gpus_per_node = num_gpus
+ cfg.dist_params.world_size = world_size
+ cfg.dist_params.nnodes = int(os.environ['NNODES'])
+ cfg.dist_params.node_rank = int(os.environ['NODE_RANK'])
+
+ # if int(os.environ['NNODES']) > 1 and cfg.dist_params.dist_url.startswith("file://"):
+ # raise Warning("file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://")
+
+
+def get_func(func_name):
+ """
+ Helper to return a function object by name. func_name must identify
+ a function in this module or the path to a function relative to the base
+ module.
+ @ func_name: function name.
+ """
+ if func_name == '':
+ return None
+ try:
+ parts = func_name.split('.')
+ # Refers to a function in this module
+ if len(parts) == 1:
+ return globals()[parts[0]]
+ # Otherwise, assume we're referencing a module under modeling
+ module_name = '.'.join(parts[:-1])
+ module = importlib.import_module(module_name)
+ return getattr(module, parts[-1])
+    except Exception as e:
+        raise RuntimeError(f'Failed to find function: {func_name}') from e
+
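+# Usage sketch for get_func (hypothetical targets): a dotted path is resolved via
+# importlib, a bare name is looked up in this module, e.g.
+#   sqrt_fn = get_func('math.sqrt')   # -> math.sqrt
+#   timer_cls = get_func('Timer')     # -> the Timer class defined below
+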
+class Timer(object):
+ """A simple timer."""
+
+ def __init__(self):
+ self.reset()
+
+ def tic(self):
+        # using time.time instead of time.clock because time.clock
+ # does not normalize for multithreading
+ self.start_time = time.time()
+
+ def toc(self, average=True):
+ self.diff = time.time() - self.start_time
+ self.total_time += self.diff
+ self.calls += 1
+ self.average_time = self.total_time / self.calls
+ if average:
+ return self.average_time
+ else:
+ return self.diff
+
+ def reset(self):
+ self.total_time = 0.
+ self.calls = 0
+ self.start_time = 0.
+ self.diff = 0.
+ self.average_time = 0.
+
+class TrainingStats(object):
+ """Track vital training statistics."""
+ def __init__(self, log_period, tensorboard_logger=None):
+ self.log_period = log_period
+ self.tblogger = tensorboard_logger
+ self.tb_ignored_keys = ['iter', 'eta', 'epoch', 'time', 'val_err']
+ self.iter_timer = Timer()
+ # Window size for smoothing tracked values (with median filtering)
+ self.filter_size = log_period
+ def create_smoothed_value():
+ return AverageMeter()
+ self.smoothed_losses = defaultdict(create_smoothed_value)
+ #self.smoothed_metrics = defaultdict(create_smoothed_value)
+ #self.smoothed_total_loss = AverageMeter()
+
+
+ def IterTic(self):
+ self.iter_timer.tic()
+
+ def IterToc(self):
+ return self.iter_timer.toc(average=False)
+
+ def reset_iter_time(self):
+ self.iter_timer.reset()
+
+ def update_iter_stats(self, losses_dict):
+ """Update tracked iteration statistics."""
+ for k, v in losses_dict.items():
+ self.smoothed_losses[k].update(float(v), 1)
+
+ def log_iter_stats(self, cur_iter, optimizer, max_iters, val_err={}):
+ """Log the tracked statistics."""
+ if (cur_iter % self.log_period == 0):
+ stats = self.get_stats(cur_iter, optimizer, max_iters, val_err)
+ log_stats(stats)
+ if self.tblogger:
+ self.tb_log_stats(stats, cur_iter)
+ for k, v in self.smoothed_losses.items():
+ v.reset()
+ self.iter_timer.reset() # reset time counting every log period
+
+ def tb_log_stats(self, stats, cur_iter):
+ """Log the tracked statistics to tensorboard"""
+ for k in stats:
+ # ignore some logs
+ if k not in self.tb_ignored_keys:
+ v = stats[k]
+ if isinstance(v, dict):
+ self.tb_log_stats(v, cur_iter)
+ else:
+ self.tblogger.add_scalar(k, v, cur_iter)
+
+
+ def get_stats(self, cur_iter, optimizer, max_iters, val_err = {}):
+ eta_seconds = self.iter_timer.average_time * (max_iters - cur_iter)
+
+ eta = str(datetime.timedelta(seconds=int(eta_seconds)))
+ stats = OrderedDict(
+ iter=cur_iter, # 1-indexed
+ time=self.iter_timer.average_time,
+ eta=eta,
+ )
+ optimizer_state_dict = optimizer.state_dict()
+ lr = {}
+ for i in range(len(optimizer_state_dict['param_groups'])):
+ lr_name = 'group%d_lr' % i
+ lr[lr_name] = optimizer_state_dict['param_groups'][i]['lr']
+
+ stats['lr'] = OrderedDict(lr)
+ for k, v in self.smoothed_losses.items():
+ stats[k] = v.avg
+
+ stats['val_err'] = OrderedDict(val_err)
+ stats['max_iters'] = max_iters
+ return stats
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+ Args:
+ @input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ @average (bool): whether to do average or sum
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
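+
+# Usage sketch for reduce_dict (illustrative names), assuming an initialized
+# process group and scalar CUDA tensors:
+#   losses = {'total_loss': loss.detach(), 'depth_loss': d_loss.detach()}
+#   losses_reduced = reduce_dict(losses)   # rank 0 now holds the averaged values
+#   if get_rank() == 0:
+#       log_values = {k: v.item() for k, v in losses_reduced.items()}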
+
+
+def log_stats(stats):
+    """Log training statistics to terminal."""
+    logger = logging.getLogger()
+ lines = "[Step %d/%d]\n" % (
+ stats['iter'], stats['max_iters'])
+
+ lines += "\t\tloss: %.3f, time: %.6f, eta: %s\n" % (
+ stats['total_loss'], stats['time'], stats['eta'])
+
+ # log loss
+ lines += "\t\t"
+ for k, v in stats.items():
+ if 'loss' in k.lower() and 'total_loss' not in k.lower():
+ lines += "%s: %.3f" % (k, v) + ", "
+ lines = lines[:-3]
+ lines += '\n'
+
+ # validate criteria
+ lines += "\t\tlast val err:" + ", ".join("%s: %.6f" % (k, v) for k, v in stats['val_err'].items()) + ", "
+ lines += '\n'
+
+ # lr in different groups
+ lines += "\t\t" + ", ".join("%s: %.8f" % (k, v) for k, v in stats['lr'].items())
+ lines += '\n'
+    logger.info(lines[:-1]) # remove trailing newline
+
diff --git a/training/mono/utils/db.py b/training/mono/utils/db.py
new file mode 100644
index 0000000000000000000000000000000000000000..164d9acaccd9cab6b4b3def26bc78cc692acd0a5
--- /dev/null
+++ b/training/mono/utils/db.py
@@ -0,0 +1,36 @@
+from types import ModuleType
+import data_server_info # data information on some server
+
+def load_data_info(module_name, data_info={}, db_type='db_info', module=None):
+ if module is None:
+ module = globals().get(module_name, None)
+ if module:
+ for key, value in module.__dict__.items():
+
+            if not key.startswith('_'):
+ if key == 'db_info':
+ data_info.update(value)
+ elif isinstance(value, ModuleType):
+ load_data_info(module_name + '.' + key, data_info, module=value)
+ else:
+ raise RuntimeError(f'Try to access "db_info", but cannot find {module_name} module.')
+
+def reset_ckpt_path(cfg, data_info):
+ if isinstance(cfg, dict):
+ for key in cfg.keys():
+ if key == 'backbone':
+ new_ckpt_path = data_info['checkpoint']['db_root'] + '/' + data_info['checkpoint'][cfg.backbone.type]
+ cfg.backbone.update(checkpoint=new_ckpt_path)
+ continue
+ elif isinstance(cfg.get(key), dict):
+ reset_ckpt_path(cfg.get(key), data_info)
+ else:
+ continue
+ else:
+ return
+
+if __name__ == '__main__':
+ db_info_tmp = {}
+ load_data_info('db_data_info', db_info_tmp)
+ print('results', db_info_tmp.keys())
+
diff --git a/training/mono/utils/do_test.py b/training/mono/utils/do_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e34fc2daabf21bd3774fcdeb5f08a749a1f57823
--- /dev/null
+++ b/training/mono/utils/do_test.py
@@ -0,0 +1,245 @@
+import torch
+import logging
+import os
+from mono.utils.avg_meter import MetricAverageMeter
+from mono.utils.visualization import save_val_imgs, visual_train_data, create_html, save_raw_imgs, save_normal_val_imgs
+import cv2
+from tqdm import tqdm
+import numpy as np
+from mono.utils.logger import setup_logger
+from mono.utils.comm import main_process
+#from scipy.optimize import minimize
+#from torchmin import minimize
+import torch.optim as optim
+from torch.autograd import Variable
+
+
+def to_cuda(data: dict):
+ for k, v in data.items():
+ if isinstance(v, torch.Tensor):
+ data[k] = v.cuda(non_blocking=True)
+ if isinstance(v, list) and len(v)>=1 and isinstance(v[0], torch.Tensor):
+ for i, l_i in enumerate(v):
+ data[k][i] = l_i.cuda(non_blocking=True)
+ return data
+
+def align_scale(pred: torch.tensor, target: torch.tensor):
+ mask = target > 0
+ if torch.sum(mask) > 10:
+ scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
+ else:
+ scale = 1
+ pred_scale = pred * scale
+ return pred_scale, scale
+
+def align_shift(pred: torch.tensor, target: torch.tensor):
+ mask = target > 0
+ if torch.sum(mask) > 10:
+ shift = torch.median(target[mask]) - (torch.median(pred[mask]) + 1e-8)
+ else:
+ shift = 0
+ pred_shift = pred + shift
+ return pred_shift, shift
+
+def align_scale_shift(pred: torch.tensor, target: torch.tensor):
+ mask = target > 0
+ target_mask = target[mask].cpu().numpy()
+ pred_mask = pred[mask].cpu().numpy()
+ if torch.sum(mask) > 10:
+ scale, shift = np.polyfit(pred_mask, target_mask, deg=1)
+ if scale < 0:
+ scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
+ shift = 0
+ else:
+ scale = 1
+ shift = 0
+ pred = pred * scale + shift
+ return pred, scale
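+
+# align_scale_shift fits target ~= scale * pred + shift with a degree-1
+# least-squares fit over valid (target > 0) pixels, falling back to a
+# median-ratio scale when the fitted scale is negative. Usage sketch:
+#   aligned_pred, s = align_scale_shift(pred_depth, gt_depth)  # illustrative call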
+
+def get_prediction(
+ model: torch.nn.Module,
+ input: torch.tensor,
+ cam_model: torch.tensor,
+ pad_info: torch.tensor,
+ scale_info: torch.tensor,
+ gt_depth: torch.tensor,
+ normalize_scale: float,
+ intrinsic = None,
+ clip_range = None,
+ flip_aug = False):
+ #clip_range = [0, 10],
+ #flip_aug = True):
+
+ data = dict(
+ input=input,
+ #ref_input=ref_input,
+ cam_model=cam_model
+ )
+ #output = model.module.inference(data)
+ output = model.module.inference(data)
+ pred_depth, confidence = output['prediction'], output['confidence']
+ pred_depth = torch.abs(pred_depth)
+ pred_depth = pred_depth.squeeze()
+
+ if flip_aug == True:
+ output_flip = model.module.inference(dict(
+ input=torch.flip(input, [3]),
+ #ref_input=ref_input,
+ cam_model=cam_model
+ ))
+
+ if clip_range != None:
+ output['prediction'] = torch.clamp(output['prediction'], clip_range[0], clip_range[1])
+ output_flip['prediction'] = torch.clamp(output_flip['prediction'], clip_range[0] / normalize_scale * scale_info , clip_range[1] / normalize_scale * scale_info)
+
+ output['prediction'] = 0.5 * (output['prediction'] + torch.flip(output_flip['prediction'], [3]))
+ output['confidence'] = 0.5 * (output['confidence'] + torch.flip(output_flip['confidence'], [3]))
+
+ output['pad'] = torch.Tensor(pad_info).cuda().unsqueeze(0).int()
+ output['mask'] = torch.ones_like(pred_depth).bool().unsqueeze(0).unsqueeze(1)
+ output['scale_info'] = scale_info
+ if intrinsic is not None:
+ output['intrinsic'] = intrinsic
+
+ pred_depth = pred_depth[pad_info[0]: pred_depth.shape[0]-pad_info[1], pad_info[2]: pred_depth.shape[1]-pad_info[3]]
+    pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], gt_depth.shape, mode='bilinear').squeeze() # to original size
+ pred_depth = pred_depth * normalize_scale / scale_info
+
+ if clip_range != None:
+ pred_depth = torch.clamp(pred_depth, clip_range[0], clip_range[1])
+
+ pred_depth_scale, scale = align_scale(pred_depth, gt_depth) #align_scale_shift(pred_depth, gt_depth)
+
+ if clip_range != None:
+ pred_depth_scale = torch.clamp(pred_depth_scale, clip_range[0], clip_range[1])
+
+ return pred_depth, pred_depth_scale, scale, output
+
+
+# def depth_normal_consistency_optimization(output_dict, consistency_fn):
+# s = torch.zeros_like(output_dict['scale_info'])
+# def closure(x):
+# output_dict['scale'] = torch.exp(x) * output_dict['scale_info']
+# error = consistency_fn(**output_dict)
+# return error + x * x
+
+# result = minimize(closure, s, method='newton-exact', disp=1, options={'max_iter':10, 'lr':0.1})
+# return float(torch.exp(-result.x))
+
+
+def do_test_with_dataloader(
+ model: torch.nn.Module,
+ cfg: dict,
+    dataloader: torch.utils.data.DataLoader,
+    logger: logging.Logger,
+ is_distributed: bool = True,
+ local_rank: int = 0):
+
+ show_dir = cfg.show_dir
+ save_interval = 100
+ save_html_path = show_dir + '/index.html'
+ save_imgs_dir = show_dir + '/vis'
+ os.makedirs(save_imgs_dir, exist_ok=True)
+ save_raw_dir = show_dir + '/raw'
+ os.makedirs(save_raw_dir, exist_ok=True)
+
+ normalize_scale = cfg.data_basic.depth_range[1]
+
+ dam = MetricAverageMeter(cfg.test_metrics)
+ dam_scale = MetricAverageMeter(cfg.test_metrics)
+
+ try:
+ depth_range = cfg.data_basic.clip_depth_range if cfg.clip_depth else None
+ except:
+ depth_range = None
+
+ for i, data in enumerate(tqdm(dataloader)):
+
+ # logger.info(f'{local_rank}: {i}/{len(dataloader)}')
+ data = to_cuda(data)
+ gt_depth = data['target'].squeeze()
+ mask = gt_depth > 1e-6
+ pad_info = data['pad']
+ pred_depth, pred_depth_scale, scale, output = get_prediction(
+ model,
+ data['input'],
+ data['cam_model'],
+ pad_info,
+ data['scale'],
+ gt_depth,
+ normalize_scale,
+ data['intrinsic'],
+ )
+
+ logger.info(f'{data["filename"]}: {scale}')
+
+ # optimization
+ #if "normal_out_list" in output.keys():
+ #scale_opt = depth_normal_consistency_optimization(output, consistency_loss)
+ #print('scale', scale_opt, float(scale))
+ scale_opt = 1.0
+
+ # update depth metrics
+ dam_scale.update_metrics_gpu(pred_depth_scale, gt_depth, mask, is_distributed)
+ dam.update_metrics_gpu(pred_depth, gt_depth, mask, is_distributed)
+
+ # save evaluation results
+ if i % save_interval == 0:
+ # save
+ rgb = data['input'][:, :, pad_info[0]: data['input'].shape[2]-pad_info[1], pad_info[2]: data['input'].shape[3]-pad_info[3]]
+ rgb = torch.nn.functional.interpolate(rgb, gt_depth.shape, mode='bilinear').squeeze()
+ max_scale = save_val_imgs(i,
+ pred_depth,
+ gt_depth,
+ rgb,
+ data['filename'][0],
+ save_imgs_dir,
+ )
+ logger.info(f'{data["filename"]}, {"max_scale"}: {max_scale}')
+
+ # # save original depth/rgb
+ # save_raw_imgs(
+ # pred_depth.cpu().squeeze().numpy(),
+ # data['raw_rgb'].cpu().squeeze().numpy(),
+ # data['filename'][0],
+ # save_raw_dir,
+ # )
+
+
+ # surface normal metrics
+ if "normal_out_list" in output.keys():
+ normal_out_list = output['normal_out_list']
+ gt_normal = data['normal']
+
+ pred_normal = normal_out_list[-1][:, :3, :, :] # (B, 3, H, W)
+ H, W = pred_normal.shape[2:]
+ pred_normal = pred_normal[:, :, pad_info[0]:H-pad_info[1], pad_info[2]:W-pad_info[3]]
+ pred_normal = torch.nn.functional.interpolate(pred_normal, size=gt_normal.shape[2:], mode='bilinear', align_corners=True)
+
+ gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True)
+            dam.update_normal_metrics_gpu(pred_normal, gt_normal, gt_normal_mask, cfg.distributed) # update metrics on valid normals
+
+ if i % save_interval == 0:
+                save_normal_val_imgs(i,
+ pred_normal,
+ gt_normal,
+ rgb, # data['input'],
+ 'normal_' + data['filename'][0],
+ save_imgs_dir,
+ )
+
+ # get validation error
+ if main_process():
+ eval_error = dam.get_metrics()
+ print('>>>>>W/o scale: ', eval_error)
+ eval_error_scale = dam_scale.get_metrics()
+ print('>>>>>W scale: ', eval_error_scale)
+ # disp_eval_error = dam_disp.get_metrics()
+ # print('>>>>>Disp to depth: ', disp_eval_error)
+ # for i, dam in enumerate(dams):
+ # print(f'>>>>>W/o scale gru{i}: ', dam.get_metrics())
+
+ logger.info(eval_error)
+ logger.info(eval_error_scale)
+ # logger.info(disp_eval_error)
+ # [logger.info(dam.get_metrics()) for dam in dams]
diff --git a/training/mono/utils/do_train.py b/training/mono/utils/do_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..69bb1e009418d832ae27033f0a7532a003492e59
--- /dev/null
+++ b/training/mono/utils/do_train.py
@@ -0,0 +1,529 @@
+import os
+import torch
+import matplotlib.pyplot as plt
+from mono.model.monodepth_model import get_configured_monodepth_model
+from tensorboardX import SummaryWriter
+from mono.utils.comm import TrainingStats
+from mono.utils.avg_meter import MetricAverageMeter
+from mono.utils.running import build_lr_schedule_with_cfg, build_optimizer_with_cfg, load_ckpt, save_ckpt
+from mono.utils.comm import reduce_dict, main_process, get_rank
+from mono.utils.visualization import save_val_imgs, visual_train_data, create_html, save_normal_val_imgs
+import traceback
+from mono.utils.visualization import create_dir_for_validate_meta
+from mono.model.criterion import build_criterions
+from mono.datasets.distributed_sampler import build_dataset_n_sampler_with_cfg, build_data_array
+from mono.utils.logger import setup_logger
+import logging
+from .misc import NativeScalerWithGradNormCount, is_bf16_supported
+import math
+import sys
+import random
+import numpy as np
+import torch.distributed as dist
+import torch.nn.functional as F
+from contextlib import nullcontext
+
+def to_cuda(data):
+ for k, v in data.items():
+ if isinstance(v, torch.Tensor):
+ data[k] = v.cuda(non_blocking=True)
+ if isinstance(v, list) and len(v)>1 and isinstance(v[0], torch.Tensor):
+ for i, l_i in enumerate(v):
+ data[k][i] = l_i.cuda(non_blocking=True)
+ return data
+
+def do_train(local_rank: int, cfg: dict):
+
+ logger = setup_logger(cfg.log_file)
+
+ # build criterions
+ criterions = build_criterions(cfg)
+
+ # build model
+ model = get_configured_monodepth_model(cfg,
+ criterions,
+ )
+
+ # log model state_dict
+ if main_process():
+ logger.info(model.state_dict().keys())
+
+ # build datasets
+ train_dataset, train_sampler = build_dataset_n_sampler_with_cfg(cfg, 'train')
+ if 'multi_dataset_eval' in cfg.evaluation and cfg.evaluation.multi_dataset_eval:
+ val_dataset = build_data_array(cfg, 'val')
+ else:
+ val_dataset, val_sampler = build_dataset_n_sampler_with_cfg(cfg, 'val')
+ # build data loaders
+ g = torch.Generator()
+ g.manual_seed(cfg.seed + cfg.dist_params.global_rank)
+ train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
+ batch_size=cfg.batchsize_per_gpu,
+ num_workers=cfg.thread_per_gpu,
+ sampler=train_sampler,
+ drop_last=True,
+ pin_memory=True,
+ generator=g,)
+ # collate_fn=collate_fn)
+ if isinstance(val_dataset, list):
+        val_dataloader = [torch.utils.data.DataLoader(dataset=val_set,
+                                                    batch_size=1,
+                                                    num_workers=0,
+                                                    sampler=torch.utils.data.distributed.DistributedSampler(val_set, shuffle=False),
+                                                    drop_last=True,
+                                                    pin_memory=True,) for val_group in val_dataset for val_set in val_group]
+ else:
+ val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset,
+ batch_size=1,
+ num_workers=0,
+ sampler=val_sampler,
+ drop_last=True,
+ pin_memory=True,)
+
+ # build schedule
+ lr_scheduler = build_lr_schedule_with_cfg(cfg)
+ optimizer = build_optimizer_with_cfg(cfg, model)
+
+ # config distributed training
+ if cfg.distributed:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = torch.nn.parallel.DistributedDataParallel(model.cuda(),
+ device_ids=[local_rank],
+ output_device=local_rank,
+ find_unused_parameters=False)
+ else:
+ model = torch.nn.DataParallel(model.cuda())
+
+ # init automatic mix precision training
+ # if 'AMP' in cfg.runner.type:
+ # loss_scaler = NativeScalerWithGradNormCount()
+ # else:
+ # loss_scaler = None
+ loss_scaler = None
+
+ # load ckpt
+ if cfg.load_from and cfg.resume_from is None:
+ model, _, _, loss_scaler = load_ckpt(cfg.load_from, model, optimizer=None, scheduler=None, strict_match=False, loss_scaler=loss_scaler)
+ elif cfg.resume_from:
+ model, optimizer, lr_scheduler, loss_scaler = load_ckpt(
+ cfg.resume_from,
+ model,
+ optimizer=optimizer,
+ scheduler=lr_scheduler,
+ strict_match=False,
+ loss_scaler=loss_scaler)
+
+ if cfg.runner.type == 'IterBasedRunner':
+ train_by_iters(cfg,
+ model,
+ optimizer,
+ lr_scheduler,
+ train_dataloader,
+ val_dataloader,
+ )
+ elif cfg.runner.type == 'IterBasedRunner_MultiSize':
+ train_by_iters_multisize(cfg,
+ model,
+ optimizer,
+ lr_scheduler,
+ train_dataloader,
+ val_dataloader,
+ )
+ elif cfg.runner.type == 'IterBasedRunner_AMP':
+ train_by_iters_amp(
+ cfg = cfg,
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ train_dataloader=train_dataloader,
+ val_dataloader=val_dataloader,
+ loss_scaler=loss_scaler
+ )
+ elif cfg.runner.type == 'IterBasedRunner_AMP_MultiSize':
+ train_by_iters_amp_multisize(
+ cfg = cfg,
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ train_dataloader=train_dataloader,
+ val_dataloader=val_dataloader,
+ loss_scaler=loss_scaler
+ )
+    elif cfg.runner.type == 'EpochBasedRunner':
+        raise RuntimeError('EpochBasedRunner is not supported currently.')
+    else:
+        raise RuntimeError(f'Runner type {cfg.runner.type} is not supported currently.')
+
+
+def train_by_iters(cfg, model, optimizer, lr_scheduler, train_dataloader, val_dataloader):
+ """
+ Do the training by iterations.
+ """
+ logger = logging.getLogger()
+ tb_logger = None
+ if cfg.use_tensorboard and main_process():
+ tb_logger = SummaryWriter(cfg.tensorboard_dir)
+ if main_process():
+ training_stats = TrainingStats(log_period=cfg.log_interval, tensorboard_logger=tb_logger)
+
+ lr_scheduler.before_run(optimizer)
+
+ # set training steps
+ max_iters = cfg.runner.max_iters
+ start_iter = lr_scheduler._step_count
+
+ save_interval = cfg.checkpoint_config.interval
+ eval_interval = cfg.evaluation.interval
+ epoch = 0
+ logger.info('Create iterator.')
+ dataloader_iterator = iter(train_dataloader)
+
+ val_err = {}
+ logger.info('Start training.')
+
+ try:
+ # for step in range(start_iter, max_iters):
+ # keep same step in all processes, avoid stuck during eval barrier
+ step = start_iter
+ while step < max_iters:
+ if main_process():
+ training_stats.IterTic()
+
+ # get the data batch
+ try:
+ data = next(dataloader_iterator)
+ except StopIteration:
+ dataloader_iterator = iter(train_dataloader)
+ data = next(dataloader_iterator)
+ except Exception as e:
+                logger.info('Error when loading training data: %s', e)
+ continue
+ except:
+ logger.info('Some training data errors exist in the current iter!')
+ continue
+ data = to_cuda(data)
+ # set random crop size
+ # if step % 10 == 0:
+ # set_random_crop_size_for_iter(train_dataloader, step, size_sample_list[step])
+
+ # check training data
+ #for i in range(data['target'].shape[0]):
+ # if 'DDAD' in data['dataset'][i] or \
+ # 'Lyft' in data['dataset'][i] or \
+ # 'DSEC' in data['dataset'][i] or \
+ # 'Argovers2' in data['dataset'][i]:
+ # replace = True
+ # else:
+ # replace = False
+ #visual_train_data(data['target'][i, ...], data['input'][i,...], data['filename'][i], cfg.work_dir, replace=replace)
+
+ # forward
+ pred_depth, losses_dict, conf = model(data)
+
+ optimizer.zero_grad()
+ losses_dict['total_loss'].backward()
+ # if step > 100 and step % 10 == 0:
+ # for param in model.parameters():
+ # print(param.grad.max(), torch.norm(param.grad))
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
+ optimizer.step()
+
+ # reduce losses over all GPUs for logging purposes
+ loss_dict_reduced = reduce_dict(losses_dict)
+
+ lr_scheduler.after_train_iter(optimizer)
+ if main_process():
+ training_stats.update_iter_stats(loss_dict_reduced)
+ training_stats.IterToc()
+ training_stats.log_iter_stats(step, optimizer, max_iters, val_err)
+
+ # validate the model
+ if cfg.evaluation.online_eval and \
+ (step+1) % eval_interval == 0 and \
+ val_dataloader is not None:
+ if isinstance(val_dataloader, list):
+ val_err = validate_multiple_dataset(cfg, step+1, model, val_dataloader, tb_logger)
+ else:
+ val_err = validate(cfg, step+1, model, val_dataloader, tb_logger)
+ if main_process():
+ training_stats.tb_log_stats(val_err, step)
+
+ # save checkpoint
+ if main_process():
+ if ((step+1) % save_interval == 0) or ((step+1)==max_iters):
+ save_ckpt(cfg, model, optimizer, lr_scheduler, step+1, epoch)
+
+ step += 1
+
+ except (RuntimeError, KeyboardInterrupt):
+ stack_trace = traceback.format_exc()
+ print(stack_trace)
+
+def train_by_iters_amp(cfg, model, optimizer, lr_scheduler, train_dataloader, val_dataloader, loss_scaler):
+ """
+ Do the training by iterations.
+ Mix precision is employed.
+ """
+ # set up logger
+ tb_logger = None
+ if cfg.use_tensorboard and main_process():
+ tb_logger = SummaryWriter(cfg.tensorboard_dir)
+ logger = logging.getLogger()
+ # training status
+ if main_process():
+ training_stats = TrainingStats(log_period=cfg.log_interval, tensorboard_logger=tb_logger)
+
+ # learning schedule
+ lr_scheduler.before_run(optimizer)
+
+ # set training steps
+ max_iters = cfg.runner.max_iters
+ start_iter = lr_scheduler._step_count
+
+ save_interval = cfg.checkpoint_config.interval
+ eval_interval = cfg.evaluation.interval
+ epoch = 0
+
+ # If it's too slow try lowering num_worker
+ # see https://discuss.pytorch.org/t/define-iterator-on-dataloader-is-very-slow/52238
+ logger.info('Create iterator.')
+ dataloader_iterator = iter(train_dataloader)
+
+ val_err = {}
+ # torch.cuda.empty_cache()
+ logger.info('Start training.')
+
+ try:
+ acc_batch = cfg.acc_batch
+ except:
+ acc_batch = 1
+
+ try:
+ # for step in range(start_iter, max_iters):
+ # keep same step in all processes, avoid stuck during eval barrier
+ step = start_iter * acc_batch
+ #while step < max_iters:
+ while True:
+
+ if main_process():
+ training_stats.IterTic()
+
+ # get the data batch
+ try:
+ data = next(dataloader_iterator)
+ except StopIteration:
+ dataloader_iterator = iter(train_dataloader)
+ data = next(dataloader_iterator)
+ except Exception as e:
+                logger.info('Error when loading training data: %s', e)
+ continue
+ except:
+ logger.info('Some training data errors exist in the current iter!')
+ continue
+
+ data = to_cuda(data)
+
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+ pred_depth, losses_dict, conf = model(data)
+
+ total_loss = losses_dict['total_loss'] / acc_batch
+
+ if not math.isfinite(total_loss):
+ logger.info("Loss is {}, skiping this batch training".format(total_loss))
+ continue
+
+ # optimize, backward
+ if (step+1-start_iter) % acc_batch == 0:
+ optimizer.zero_grad()
+ if loss_scaler == None:
+ total_loss.backward()
+ try:
+ if (step+1-start_iter) % acc_batch == 0:
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 2.5, error_if_nonfinite=True)
+ optimizer.step()
+ except:
+ print('NAN gradient, skipping optimizer.step() for this round...')
+ else:
+ loss_scaler(total_loss, optimizer, clip_grad=5, parameters=model.parameters(), update_grad=True)
+
+ # reduce losses over all GPUs for logging purposes
+ if (step+1-start_iter) % acc_batch == 0:
+ loss_dict_reduced = reduce_dict(losses_dict)
+ lr_scheduler.after_train_iter(optimizer)
+
+ if main_process():
+ training_stats.update_iter_stats(loss_dict_reduced)
+ training_stats.IterToc()
+ training_stats.log_iter_stats(step//acc_batch, optimizer, max_iters, val_err)
+
+ # validate the model
+ if cfg.evaluation.online_eval and \
+ ((step+acc_batch)//acc_batch) % eval_interval == 0 and \
+ val_dataloader is not None:
+ # if True:
+ if isinstance(val_dataloader, list):
+ val_err = validate_multiple_dataset(cfg, ((step+acc_batch)//acc_batch), model, val_dataloader, tb_logger)
+ else:
+ val_err = validate(cfg, ((step+acc_batch)//acc_batch), model, val_dataloader, tb_logger)
+ if main_process():
+ training_stats.tb_log_stats(val_err, step)
+
+ # save checkpoint
+ if main_process():
+ if (((step+acc_batch)//acc_batch) % save_interval == 0) or (((step+acc_batch)//acc_batch)==max_iters):
+ save_ckpt(cfg, model, optimizer, lr_scheduler, ((step+acc_batch)//acc_batch), epoch, loss_scaler=loss_scaler)
+
+ step += 1
+
+
+ except (RuntimeError, KeyboardInterrupt):
+ stack_trace = traceback.format_exc()
+ print(stack_trace)
+
+def validate_multiple_dataset(cfg, iter, model, val_dataloaders, tb_logger):
+ val_errs = {}
+ for val_dataloader in val_dataloaders:
+ val_err = validate(cfg, iter, model, val_dataloader, tb_logger)
+ val_errs.update(val_err)
+ # mean of all dataset
+ mean_val_err = {}
+ for k, v in val_errs.items():
+ metric = 'AllData_eval/' + k.split('/')[-1]
+ if metric not in mean_val_err.keys():
+ mean_val_err[metric] = 0
+ mean_val_err[metric] += v / len(val_dataloaders)
+ val_errs.update(mean_val_err)
+
+ return val_errs
+
+
+def validate(cfg, iter, model, val_dataloader, tb_logger):
+ """
+ Validate the model on single dataset
+ """
+ model.eval()
+ dist.barrier()
+ logger = logging.getLogger()
+ # prepare dir for visualization data
+ save_val_meta_data_dir = create_dir_for_validate_meta(cfg.work_dir, iter)
+ # save_html_path = save_val_meta_data_dir + '.html'
+ dataset_name = val_dataloader.dataset.data_name
+
+ save_point = max(int(len(val_dataloader) / 5), 1)
+ # save_point = 2
+ # depth metric meter
+ dam = MetricAverageMeter(cfg.evaluation.metrics)
+ # dam_disp = MetricAverageMeter([m for m in cfg.evaluation.metrics if m[:6]!='normal'])
+ for i, data in enumerate(val_dataloader):
+ if i % 10 == 0:
+ logger.info(f'Validation step on {dataset_name}: {i}')
+ data = to_cuda(data)
+ output = model.module.inference(data)
+ pred_depth = output['prediction']
+ pred_depth = pred_depth.squeeze()
+ gt_depth = data['target'].cuda(non_blocking=True).squeeze()
+
+ pad = data['pad'].squeeze()
+ H, W = pred_depth.shape
+ pred_depth = pred_depth[pad[0]:H-pad[1], pad[2]:W-pad[3]]
+ gt_depth = gt_depth[pad[0]:H-pad[1], pad[2]:W-pad[3]]
+ rgb = data['input'][0, :, pad[0]:H-pad[1], pad[2]:W-pad[3]]
+ mask = gt_depth > 0
+ #pred_depth_resize = cv2.resize(pred_depth.cpu().numpy(), (torch.squeeze(data['B_raw']).shape[1], torch.squeeze(data['B_raw']).shape[0]))
+ dam.update_metrics_gpu(pred_depth, gt_depth, mask, cfg.distributed)
+
+ # save evaluation results
+ if i%save_point == 0 and main_process():
+ save_val_imgs(iter,
+ pred_depth,
+ gt_depth,
+ rgb, # data['input'],
+ dataset_name + '_' + data['filename'][0],
+ save_val_meta_data_dir,
+ tb_logger=tb_logger)
+
+ ## surface normal
+ if "normal_out_list" in output.keys():
+ normal_out_list = output['normal_out_list']
+ pred_normal = normal_out_list[-1][:, :3, :, :] # (B, 3, H, W)
+ gt_normal = data['normal'].cuda(non_blocking=True)
+ # if pred_normal.shape != gt_normal.shape:
+ # pred_normal = F.interpolate(pred_normal, size=[gt_normal.size(2), gt_normal.size(3)], mode='bilinear', align_corners=True)
+
+ H, W = pred_normal.shape[2:]
+ pred_normal = pred_normal[:, :, pad[0]:H-pad[1], pad[2]:W-pad[3]]
+ gt_normal = gt_normal[:, :, pad[0]:H-pad[1], pad[2]:W-pad[3]]
+ gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True)
+ dam.update_normal_metrics_gpu(pred_normal, gt_normal, gt_normal_mask, cfg.distributed)
+
+            # save valid normal maps
+ if i%save_point == 0 and main_process():
+ save_normal_val_imgs(iter,
+ pred_normal,
+ gt_normal,
+ rgb, # data['input'],
+ dataset_name + '_normal_' + data['filename'][0],
+ save_val_meta_data_dir,
+ tb_logger=tb_logger)
+
+ # create html for visualization
+ merged_rgb_pred_gt = os.path.join(save_val_meta_data_dir, '*_merge.jpg')
+ name2path = dict(merg=merged_rgb_pred_gt) #dict(rgbs=rgbs, pred=pred, gt=gt)
+ # if main_process():
+ # create_html(name2path, save_path=save_html_path, size=(256*3, 512))
+
+ # get validation error
+ eval_error = dam.get_metrics()
+ eval_error = {f'{dataset_name}_eval/{k}': v for k,v in eval_error.items()}
+ # eval_disp_error = {f'{dataset_name}_eval/disp_{k}': v for k,v in dam_disp.get_metrics().items()}
+ # eval_error.update(eval_disp_error)
+
+ model.train()
+
+ if 'exclude' in cfg.evaluation and dataset_name in cfg.evaluation.exclude:
+ return {}
+ return eval_error
+
+def set_random_crop_size_for_iter(dataloader: torch.utils.data.dataloader.DataLoader, iter: int, size_pool=None):
+ if size_pool is None:
+ size_pool = [
+ # [504, 504], [560, 1008], [840, 1512], [1120, 2016],
+ [560, 1008], [840, 1512], [1120, 2016],
+ # [480, 768], [480, 960],
+ # [480, 992], [480, 1024],
+ # [480, 1120],
+ # [480, 1280],
+ # [480, 1312],
+ # [512, 512], [512, 640],
+ # [512, 960],
+ # [512, 992],
+ # [512, 1024], [512, 1120],
+ # [512, 1216],
+ # [512, 1280],
+ # [576, 640], [576, 960],
+ # [576, 992],
+ # [576, 1024],
+ # [608, 608], [608, 640],
+ # [608, 960], [608, 1024],
+ ]
+ random.seed(iter)
+ sample = random.choice(size_pool)
+ # idx = (iter // 10) % len(size_pool)
+ #sample = size_pool[size_idx]
+
+ # random.seed(iter)
+ # flg = random.random() <= 1.0
+ # if flg:
+ crop_size = sample
+ # else:
+ # crop_size = [sample[1], sample[0]]
+
+ # set crop size for each dataset
+ datasets_groups = len(dataloader.dataset.datasets)
+ for i in range(datasets_groups):
+ for j in range(len(dataloader.dataset.datasets[i].datasets)):
+ dataloader.dataset.datasets[i].datasets[j].set_random_crop_size(crop_size)
+ return crop_size
+
+
+
\ No newline at end of file
diff --git a/training/mono/utils/inverse_warp.py b/training/mono/utils/inverse_warp.py
new file mode 100644
index 0000000000000000000000000000000000000000..9511b77e99988a9e9f7a2af766d5bedc47ce1aa7
--- /dev/null
+++ b/training/mono/utils/inverse_warp.py
@@ -0,0 +1,316 @@
+import torch
+import torch.nn.functional as F
+
+pixel_coords = None
+
+def set_id_grid(depth):
+ global pixel_coords
+ b, h, w = depth.size()
+ i_range = torch.arange(0, h).view(1, h, 1).expand(
+ 1, h, w).type_as(depth) # [1, H, W]
+ j_range = torch.arange(0, w).view(1, 1, w).expand(
+ 1, h, w).type_as(depth) # [1, H, W]
+ ones = torch.ones(1, h, w).type_as(depth)
+
+ pixel_coords = torch.stack((j_range, i_range, ones), dim=1) # [1, 3, H, W]
+
+
+def check_sizes(input, input_name, expected):
+ condition = [input.ndimension() == len(expected)]
+ for i, size in enumerate(expected):
+ if size.isdigit():
+ condition.append(input.size(i) == int(size))
+ assert(all(condition)), "wrong size for {}, expected {}, got {}".format(
+ input_name, 'x'.join(expected), list(input.size()))
+
+
+def pixel2cam(depth, intrinsics_inv):
+ global pixel_coords
+ """Transform coordinates in the pixel frame to the camera frame.
+ Args:
+ depth: depth maps -- [B, H, W]
+ intrinsics_inv: intrinsics_inv matrix for each element of batch -- [B, 3, 3]
+ Returns:
+ array of (u,v,1) cam coordinates -- [B, 3, H, W]
+ """
+ b, h, w = depth.size()
+ if (pixel_coords is None) or pixel_coords.size(2) < h:
+ set_id_grid(depth)
+ current_pixel_coords = pixel_coords[:, :, :h, :w].expand(
+ b, 3, h, w).reshape(b, 3, -1) # [B, 3, H*W]
+ cam_coords = (intrinsics_inv @ current_pixel_coords).reshape(b, 3, h, w)
+ out = depth.unsqueeze(1) * cam_coords
+ return out
+
+
+def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode):
+ """Transform coordinates in the camera frame to the pixel frame.
+ Args:
+ cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 4, H, W]
+ proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4]
+ proj_c2p_tr: translation vectors of cameras -- [B, 3, 1]
+ Returns:
+ array of [-1,1] coordinates -- [B, 2, H, W]
+ """
+ b, _, h, w = cam_coords.size()
+ cam_coords_flat = cam_coords.reshape(b, 3, -1) # [B, 3, H*W]
+ if proj_c2p_rot is not None:
+ pcoords = proj_c2p_rot @ cam_coords_flat
+ else:
+ pcoords = cam_coords_flat
+
+ if proj_c2p_tr is not None:
+ pcoords = pcoords + proj_c2p_tr # [B, 3, H*W]
+ X = pcoords[:, 0]
+ Y = pcoords[:, 1]
+ Z = pcoords[:, 2].clamp(min=1e-3)
+
+ # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W]
+ X_norm = 2*(X / Z)/(w-1) - 1
+ Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W]
+
+ pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2]
+ return pixel_coords.reshape(b, h, w, 2)
+
+
+def euler2mat(angle):
+ """Convert euler angles to rotation matrix.
+ Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174
+ Args:
+ angle: rotation angle along 3 axis (in radians) -- size = [B, 3]
+ Returns:
+ Rotation matrix corresponding to the euler angles -- size = [B, 3, 3]
+ """
+ B = angle.size(0)
+ x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
+
+ cosz = torch.cos(z)
+ sinz = torch.sin(z)
+
+ zeros = z.detach()*0
+ ones = zeros.detach()+1
+ zmat = torch.stack([cosz, -sinz, zeros,
+ sinz, cosz, zeros,
+ zeros, zeros, ones], dim=1).reshape(B, 3, 3)
+
+ cosy = torch.cos(y)
+ siny = torch.sin(y)
+
+ ymat = torch.stack([cosy, zeros, siny,
+ zeros, ones, zeros,
+ -siny, zeros, cosy], dim=1).reshape(B, 3, 3)
+
+ cosx = torch.cos(x)
+ sinx = torch.sin(x)
+
+ xmat = torch.stack([ones, zeros, zeros,
+ zeros, cosx, -sinx,
+ zeros, sinx, cosx], dim=1).reshape(B, 3, 3)
+
+ rotMat = xmat @ ymat @ zmat
+ return rotMat
+
+
+def quat2mat(quat):
+ """Convert quaternion coefficients to rotation matrix.
+ Args:
+        quat: first three coefficients of the rotation quaternion; the fourth is computed to give a unit norm -- size = [B, 3]
+ Returns:
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
+ """
+ norm_quat = torch.cat([quat[:, :1].detach()*0 + 1, quat], dim=1)
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
+ w, x, y, z = norm_quat[:, 0], norm_quat[:,
+ 1], norm_quat[:, 2], norm_quat[:, 3]
+
+ B = quat.size(0)
+
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
+ wx, wy, wz = w*x, w*y, w*z
+ xy, xz, yz = x*y, x*z, y*z
+
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).reshape(B, 3, 3)
+ return rotMat
+
+
+def pose_vec2mat(vec, rotation_mode='euler'):
+ """
+ Convert 6DoF parameters to transformation matrix.
+    Args:
+ vec: 6DoF parameters in the order of tx, ty, tz, rx, ry, rz -- [B, 6]
+ Returns:
+ A transformation matrix -- [B, 3, 4]
+ """
+ translation = vec[:, :3].unsqueeze(-1) # [B, 3, 1]
+ rot = vec[:, 3:]
+ if rotation_mode == 'euler':
+ rot_mat = euler2mat(rot) # [B, 3, 3]
+ elif rotation_mode == 'quat':
+ rot_mat = quat2mat(rot) # [B, 3, 3]
+ transform_mat = torch.cat([rot_mat, translation], dim=2) # [B, 3, 4]
+ return transform_mat
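+
+# Sanity sketch for pose_vec2mat: a zero 6DoF vector maps to the identity
+# transform [I | 0] (illustrative check):
+#   vec = torch.zeros(1, 6)
+#   T = pose_vec2mat(vec)  # -> shape [1, 3, 4] with R = I, t = 0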
+
+
+def inverse_warp(img, depth, pose, intrinsics, rotation_mode='euler', padding_mode='zeros'):
+ """
+ Inverse warp a source image to the target image plane.
+ Args:
+ img: the source image (where to sample pixels) -- [B, 3, H, W]
+ depth: depth map of the target image -- [B, H, W]
+ pose: 6DoF pose parameters from target to source -- [B, 6]
+ intrinsics: camera intrinsic matrix -- [B, 3, 3]
+ Returns:
+ projected_img: Source image warped to the target image plane
+ valid_points: Boolean array indicating point validity
+ """
+ check_sizes(img, 'img', 'B3HW')
+ check_sizes(depth, 'depth', 'BHW')
+ check_sizes(pose, 'pose', 'B6')
+ check_sizes(intrinsics, 'intrinsics', 'B33')
+
+ batch_size, _, img_height, img_width = img.size()
+
+ cam_coords = pixel2cam(depth, intrinsics.inverse()) # [B,3,H,W]
+
+ pose_mat = pose_vec2mat(pose, rotation_mode) # [B,3,4]
+
+ # Get projection matrix for tgt camera frame to source pixel frame
+ proj_cam_to_src_pixel = intrinsics @ pose_mat # [B, 3, 4]
+
+ rot, tr = proj_cam_to_src_pixel[:, :, :3], proj_cam_to_src_pixel[:, :, -1:]
+ src_pixel_coords = cam2pixel(
+ cam_coords, rot, tr, padding_mode) # [B,H,W,2]
+ projected_img = F.grid_sample(
+ img, src_pixel_coords, padding_mode=padding_mode)
+
+ valid_points = src_pixel_coords.abs().max(dim=-1)[0] <= 1
+
+ return projected_img, valid_points
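+
+# Usage sketch for inverse_warp (illustrative shapes and intrinsics, not taken
+# from the repository configs); with an identity pose the warped image should
+# reproduce the source wherever the projection stays inside the frame:
+#   img = torch.rand(1, 3, 128, 160)
+#   depth = torch.ones(1, 128, 160)
+#   pose = torch.zeros(1, 6)  # tx, ty, tz, rx, ry, rz
+#   K = torch.tensor([[[100., 0., 80.], [0., 100., 64.], [0., 0., 1.]]])
+#   warped, valid = inverse_warp(img, depth, pose, K)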
+
+
+def cam2pixel2(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode):
+ """Transform coordinates in the camera frame to the pixel frame.
+ Args:
+ cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 4, H, W]
+ proj_c2p_rot: rotation matrix of cameras -- [B, 3, 4]
+ proj_c2p_tr: translation vectors of cameras -- [B, 3, 1]
+ Returns:
+ array of [-1,1] coordinates -- [B, 2, H, W]
+ """
+ b, _, h, w = cam_coords.size()
+ cam_coords_flat = cam_coords.reshape(b, 3, -1) # [B, 3, H*W]
+ if proj_c2p_rot is not None:
+ pcoords = proj_c2p_rot @ cam_coords_flat
+ else:
+ pcoords = cam_coords_flat
+
+ if proj_c2p_tr is not None:
+ pcoords = pcoords + proj_c2p_tr # [B, 3, H*W]
+ X = pcoords[:, 0]
+ Y = pcoords[:, 1]
+ Z = pcoords[:, 2].clamp(min=1e-3)
+
+ # Normalized, -1 if on extreme left, 1 if on extreme right (x = w-1) [B, H*W]
+ X_norm = 2*(X / Z)/(w-1) - 1
+ Y_norm = 2*(Y / Z)/(h-1) - 1 # Idem [B, H*W]
+ if padding_mode == 'zeros':
+ X_mask = ((X_norm > 1)+(X_norm < -1)).detach()
+        # make sure that no point in the warped image is a combination of image and gray (padding)
+ X_norm[X_mask] = 2
+ Y_mask = ((Y_norm > 1)+(Y_norm < -1)).detach()
+ Y_norm[Y_mask] = 2
+
+ pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2]
+ return pixel_coords.reshape(b, h, w, 2), Z.reshape(b, 1, h, w)
+
+
+def inverse_warp2(img, depth, ref_depth, pose, intrinsics, padding_mode='zeros'):
+ """
+ Inverse warp a source image to the target image plane.
+ Args:
+ img: the source image (where to sample pixels) -- [B, 3, H, W]
+ depth: depth map of the target image -- [B, 1, H, W]
+ ref_depth: the source depth map (where to sample depth) -- [B, 1, H, W]
+ pose: 6DoF pose parameters from target to source -- [B, 6]
+ intrinsics: camera intrinsic matrix -- [B, 3, 3]
+ Returns:
+ projected_img: Source image warped to the target image plane
+ valid_mask: Float array indicating point validity
+ projected_depth: sampled depth from source image
+ computed_depth: computed depth of source image using the target depth
+ """
+ check_sizes(img, 'img', 'B3HW')
+ check_sizes(depth, 'depth', 'B1HW')
+ check_sizes(ref_depth, 'ref_depth', 'B1HW')
+ check_sizes(pose, 'pose', 'B6')
+ check_sizes(intrinsics, 'intrinsics', 'B33')
+
+ batch_size, _, img_height, img_width = img.size()
+
+ cam_coords = pixel2cam(depth.squeeze(1), intrinsics.inverse()) # [B,3,H,W]
+
+ pose_mat = pose_vec2mat(pose) # [B,3,4]
+
+ # Get projection matrix for tgt camera frame to source pixel frame
+ proj_cam_to_src_pixel = intrinsics @ pose_mat # [B, 3, 4]
+
+ rot, tr = proj_cam_to_src_pixel[:, :, :3], proj_cam_to_src_pixel[:, :, -1:]
+ src_pixel_coords, computed_depth = cam2pixel2(cam_coords, rot, tr, padding_mode) # [B,H,W,2]
+ projected_img = F.grid_sample(img, src_pixel_coords, padding_mode=padding_mode, align_corners=False)
+
+ projected_depth = F.grid_sample(ref_depth, src_pixel_coords, padding_mode=padding_mode, align_corners=False)
+
+ return projected_img, projected_depth, computed_depth
+
+
+def inverse_rotation_warp(img, rot, intrinsics, padding_mode='zeros'):
+
+ b, _, h, w = img.size()
+ cam_coords = pixel2cam(torch.ones(b, h, w).type_as(img), intrinsics.inverse()) # [B,3,H,W]
+
+ rot_mat = euler2mat(rot) # [B, 3, 3]
+
+ # Get projection matrix for tgt camera frame to source pixel frame
+ proj_cam_to_src_pixel = intrinsics @ rot_mat # [B, 3, 3]
+
+ src_pixel_coords, computed_depth = cam2pixel2(cam_coords, proj_cam_to_src_pixel, None, padding_mode) # [B,H,W,2]
+ projected_img = F.grid_sample(img, src_pixel_coords, padding_mode=padding_mode, align_corners=True)
+
+ return projected_img
+
+def grid_to_flow(grid):
+ b, h, w, _ = grid.size()
+ i_range = torch.arange(0, h).view(1, h, 1).expand(1, h, w).type_as(grid) # [1, H, W]
+ j_range = torch.arange(0, w).view(1, 1, w).expand(1, h, w).type_as(grid) # [1, H, W]
+ image_coords = torch.stack((j_range, i_range), dim=1) # [1, 2, H, W]
+
+ flow = torch.zeros_like(grid).type_as(grid)
+ flow[:, :, :, 0] = (grid[:, :, :, 0]+1) / 2 * (w-1)
+ flow[:, :, :, 1] = (grid[:, :, :, 1]+1) / 2 * (h-1)
+ flow = flow.permute([0, 3, 1, 2])
+
+ flow -= image_coords
+
+ return flow
+
+def compute_translation_flow(depth, pose, intrinsics):
+ cam_coords = pixel2cam(depth.squeeze(1), intrinsics.inverse()) # [B,3,H,W]
+
+ pose_mat = pose_vec2mat(pose) # [B,3,4]
+
+ # Get projection matrix for tgt camera frame to source pixel frame
+ proj_cam_to_src_pixel = intrinsics @ pose_mat # [B, 3, 4]
+
+ rot, tr = proj_cam_to_src_pixel[:, :, :3], proj_cam_to_src_pixel[:, :, -1:]
+
+ grid_all, _ = cam2pixel2(cam_coords, rot, tr, padding_mode='zeros') # [B,H,W,2]
+ grid_rot, _ = cam2pixel2(cam_coords, rot, None, padding_mode='zeros') # [B,H,W,2]
+
+ flow_all = grid_to_flow(grid_all)
+ flow_rot = grid_to_flow(grid_rot)
+ flow_tr = (flow_all - flow_rot)
+
+ return flow_tr
+
diff --git a/training/mono/utils/logger.py b/training/mono/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a15e3233680d964551e98d95db83e9194a943ed
--- /dev/null
+++ b/training/mono/utils/logger.py
@@ -0,0 +1,105 @@
+import atexit
+import logging
+import os
+import sys
+import time
+import torch
+from termcolor import colored
+
+__all__ = ["setup_logger", ]
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
+ if len(self._abbrev_name):
+ self._abbrev_name = self._abbrev_name + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+def setup_logger(
+ output=None, distributed_rank=0, *, name='mono@YvanYin', color=True, abbrev_name=None
+):
+ """
+    Initialize the logger (adapted from detectron2) and set its verbosity level to "DEBUG".
+ Args:
+ output (str): a file name or a directory to save log. If None, will not save log file.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+ abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
+ Set to "" to not log the root module in logs.
+ By default, will abbreviate "detectron2" to "d2" and leave other
+ modules unchanged.
+ Returns:
+ logging.Logger: a logger
+ """
+ logger = logging.getLogger()
+ logger.setLevel(logging.DEBUG)
+ logger.propagate = False
+
+ if abbrev_name is None:
+ abbrev_name = "d2"
+
+ plain_formatter = logging.Formatter(
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
+ )
+ # stdout logging: master only
+ if distributed_rank == 0:
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.DEBUG)
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S",
+ root_name=name,
+ abbrev_name=str(abbrev_name),
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # file logging: all workers
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "log.txt")
+ if distributed_rank > 0:
+ filename = filename + ".rank{}".format(distributed_rank)
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ # fh = logging.FileHandler(output, 'w')
+ fh = logging.StreamHandler(_cached_log_stream(filename))
+ fh.setLevel(logging.DEBUG)
+ fh.setFormatter(plain_formatter)
+ logger.addHandler(fh)
+
+
+ return logger
+
+
+from iopath.common.file_io import PathManager as PathManagerBase
+
+
+
+PathManager = PathManagerBase()
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+def _cached_log_stream(filename):
+ # use 1K buffer if writing to cloud storage
+ io = PathManager.open(filename, "a", buffering=1024 if "://" in filename else -1)
+ atexit.register(io.close)
+ return io
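+
+
+# Minimal usage sketch (illustrative only): colored logs go to stdout and a
+# plain-text copy to ./tmp_logs/log.txt; the output directory is arbitrary.
+if __name__ == "__main__":
+    demo_logger = setup_logger(output="./tmp_logs", distributed_rank=0)
+    demo_logger.info("logger initialized")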
diff --git a/training/mono/utils/logit_to_depth.py b/training/mono/utils/logit_to_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..30dea4ee2d5a1b2da715e0900c65ae8c309f6eb5
--- /dev/null
+++ b/training/mono/utils/logit_to_depth.py
@@ -0,0 +1,58 @@
+import torch
+import torch.nn as nn
+
+class SoftWeight(nn.Module):
+ """
+ Transfer n-channel discrete depth bins to a depth map.
+ Args:
+        @pred_logit: n-channel output of the network, [b, c, h, w]
+ Return: 1-channel depth, [b, 1, h, w]
+ """
+ def __init__(self, depth_bins_border):
+ super(SoftWeight, self).__init__()
+ self.register_buffer("depth_bins_border", torch.tensor(depth_bins_border), persistent=False)
+
+ def forward(self, pred_logit):
+ if type(pred_logit).__module__ != torch.__name__:
+ pred_logit = torch.tensor(pred_logit, dtype=torch.float32, device="cuda")
+ pred_score = nn.functional.softmax(pred_logit, dim=1)
+ pred_score_ch = pred_score.permute(0, 2, 3, 1) #[b, h, w, c]
+ pred_score_weight = pred_score_ch * self.depth_bins_border
+ depth_log = torch.sum(pred_score_weight, dim=3, dtype=torch.float32, keepdim=True)
+ depth = 10 ** depth_log
+ depth = depth.permute(0, 3, 1, 2) # [b, 1, h, w]
+ confidence, _ = torch.max(pred_logit, dim=1, keepdim=True)
+ return depth, confidence
+
+def soft_weight(pred_logit, depth_bins_border):
+ """
+ Transfer n-channel discrete depth bins to depth map.
+ Args:
+        @pred_logit: n-channel output of the network, [b, c, h, w]
+        @depth_bins_border: log10-space centers of the depth bins, [c]
+ Return: 1-channel depth, [b, 1, h, w]
+ """
+ if type(pred_logit).__module__ != torch.__name__:
+ pred_logit = torch.tensor(pred_logit, dtype=torch.float32, device="cuda")
+ if type(depth_bins_border).__module__ != torch.__name__:
+ depth_bins_border = torch.tensor(depth_bins_border, dtype=torch.float32, device="cuda")
+
+ pred_score = nn.functional.softmax(pred_logit, dim=1)
+    pred_score_ch = pred_score.permute(0, 2, 3, 1) #[b, h, w, c]
+    depth_log = torch.sum(pred_score_ch * depth_bins_border, dim=3, dtype=torch.float32, keepdim=True)
+    depth = 10 ** depth_log
+ depth = depth.permute(0, 3, 1, 2) # [b, 1, h, w]
+
+ confidence, _ = torch.max(pred_logit, dim=1, keepdim=True)
+ return depth, confidence
+
+
+
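+# Illustrative sketch of the bin-to-depth conversion: the network predicts one
+# score per log10-spaced depth bin, the softmax-weighted sum of the bin centers
+# gives log10(depth), and 10 ** (.) recovers metric depth. The bin count and
+# depth range below are assumptions for this example only.
+def _demo_soft_weight():
+    num_bins = 200
+    borders = torch.linspace(-0.301, 2.0, num_bins).tolist()  # log10 of ~0.5m .. 100m
+    logits = torch.randn(1, num_bins, 8, 8)                   # fake network output
+    depth, conf = SoftWeight(borders)(logits)                 # each [1, 1, 8, 8]
+    return depth, conf
+
+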
+if __name__ == '__main__':
+ import numpy as np
+ depth_max = 100
+ depth_min = 0.5
+
+ depth_bin_interval = (np.log10(depth_max) - np.log10(depth_min)) / 200
+ depth_bins_border = [np.log10(depth_min) + depth_bin_interval * (i + 0.5)
+ for i in range(200)]
+
+ sw = SoftWeight(depth_bins_border)
\ No newline at end of file
diff --git a/training/mono/utils/misc.py b/training/mono/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..2947c2cd92cd91f8f5f8e1b692edd49f4bfad58f
--- /dev/null
+++ b/training/mono/utils/misc.py
@@ -0,0 +1,67 @@
+
+
+
+import os
+import torch
+try:
+ from torch._six import inf
+except ImportError:
+ from torch import inf
+
+
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self):
+ #self._scaler = torch.cuda.amp.GradScaler(init_scale=16384) #init_scale=4096.0
+ self._scaler = torch.cuda.amp.GradScaler(init_scale=1)
+
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
+ try:
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad, error_if_nonfinite=True)
+                except RuntimeError:
+                    norm = None
+                    print('NAN gradient ....')
+ else:
+ raise NotImplementedError
+ self._scaler.unscale_(optimizer)
+ norm = get_grad_norm_(parameters)
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ norm = None
+ return True
+ #return norm
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+ return total_norm
+
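+# Illustrative training-step sketch (not used by the pipeline): scale the loss,
+# clip gradients and step the optimizer through the wrapper above. The model,
+# optimizer and data below are placeholders for the example only.
+def _demo_scaler_step():
+    model = torch.nn.Linear(4, 1)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+    loss_scaler = NativeScalerWithGradNormCount()
+    loss = model(torch.randn(8, 4)).mean()
+    optimizer.zero_grad()
+    loss_scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())
+    return loss_scaler.state_dict()
+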
+def is_bf16_supported():
+ """Returns a bool indicating if the current CUDA device supports dtype bfloat16"""
+ cu_vers = torch.version.cuda
+ if cu_vers is not None:
+ cuda_maj_decide = int(cu_vers.split('.')[0]) >= 11
+ else:
+ cuda_maj_decide = False
+ return torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8 and cuda_maj_decide
\ No newline at end of file
diff --git a/training/mono/utils/pcd_utils.py b/training/mono/utils/pcd_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d409764e35fda2a3fda2c771a4cfdc613e603da
--- /dev/null
+++ b/training/mono/utils/pcd_utils.py
@@ -0,0 +1,52 @@
+import os
+import numpy as np
+from plyfile import PlyData, PlyElement
+
+
+def save_point_cloud(pcd, rgb, filename, binary=True):
+ """Save an RGB point cloud as a PLY file.
+ :paras
+ @pcd: Nx3 matrix, the XYZ coordinates
+    @rgb: Nx3 matrix, the RGB colors for each 3D point (may be None, in which case gray is used)
+ """
+    if rgb is None:
+        gray_concat = np.tile(np.array([128], dtype=np.uint8), (pcd.shape[0], 3))
+        points_3d = np.hstack((pcd, gray_concat))
+    else:
+        assert pcd.shape[0] == rgb.shape[0]
+        points_3d = np.hstack((pcd, rgb))
+ python_types = (float, float, float, int, int, int)
+ npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'),
+ ('blue', 'u1')]
+ if binary is True:
+ # Format into NumPy structured array
+ vertices = []
+ for row_idx in range(points_3d.shape[0]):
+ cur_point = points_3d[row_idx]
+ vertices.append(tuple(dtype(point) for dtype, point in zip(python_types, cur_point)))
+ vertices_array = np.array(vertices, dtype=npy_types)
+ el = PlyElement.describe(vertices_array, 'vertex')
+
+ # Write
+ PlyData([el]).write(filename)
+ else:
+ x = np.squeeze(points_3d[:, 0])
+ y = np.squeeze(points_3d[:, 1])
+ z = np.squeeze(points_3d[:, 2])
+ r = np.squeeze(points_3d[:, 3])
+ g = np.squeeze(points_3d[:, 4])
+ b = np.squeeze(points_3d[:, 5])
+
+ ply_head = 'ply\n' \
+ 'format ascii 1.0\n' \
+ 'element vertex %d\n' \
+ 'property float x\n' \
+ 'property float y\n' \
+ 'property float z\n' \
+ 'property uchar red\n' \
+ 'property uchar green\n' \
+ 'property uchar blue\n' \
+ 'end_header' % r.shape[0]
+ # ---- Save ply data to disk
+ np.savetxt(filename, np.column_stack((x, y, z, r, g, b)), fmt="%d %d %d %d %d %d", header=ply_head, comments='')
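+
+
+# Minimal usage sketch (illustrative only): save a random colored point cloud
+# to 'demo_cloud.ply'; the file name is arbitrary.
+if __name__ == '__main__':
+    pts = np.random.rand(100, 3).astype(np.float32)
+    colors = np.random.randint(0, 255, (100, 3), dtype=np.uint8)
+    save_point_cloud(pts, colors, 'demo_cloud.ply', binary=True)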
diff --git a/training/mono/utils/raindropper/__init__.py b/training/mono/utils/raindropper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/training/mono/utils/raindropper/config.py b/training/mono/utils/raindropper/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..06b1f211dbcbe37699b96ac2e2bbf60e79c24dda
--- /dev/null
+++ b/training/mono/utils/raindropper/config.py
@@ -0,0 +1,24 @@
+"""
+Arguments:
+maxR -- maximum drop radius
+minR -- minimum drop radius
+maxDrops -- maximum number of drops in the image
+minDrops -- minimum number of drops in the image
+edge_darkratio -- brightness reduction factor for drop edges
+return_label -- flag defining whether a label will be returned or just an image with generated raindrops
+A, B, C, D -- unused in this code; an older version used them as Bezier control points
+"""
+
+cfg = {
+    'maxR': 35, # must not exceed 150
+ 'minR': 10,
+ 'maxDrops': 50,
+ 'minDrops': 15,
+ 'edge_darkratio': 1.0,
+ 'return_label': True,
+ 'label_thres': 128,
+ 'A': (1, 4.5),
+ 'B': (3, 1),
+ 'C': (1, 3),
+ 'D': (3, 3)
+}
diff --git a/training/mono/utils/raindropper/dropgenerator.py b/training/mono/utils/raindropper/dropgenerator.py
new file mode 100644
index 0000000000000000000000000000000000000000..39b72c787316520af784655a61962b08b039c6c1
--- /dev/null
+++ b/training/mono/utils/raindropper/dropgenerator.py
@@ -0,0 +1,425 @@
+# change rainy drop func from
+# https://github.com/EvoCargo/RaindropsOnWindshield/blob/main/raindrops_generator/raindrop/dropgenerator.py
+
+import math
+import random
+from random import randint
+
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw, ImageEnhance
+from skimage.measure import label as skimage_label
+
+from .raindrop import Raindrop, make_bezier
+
+
+def CheckCollision(DropList):
+ """This function handle the collision of the drops.
+
+ :param DropList: list of raindrop class objects
+ """
+ listFinalDrops = []
+ Checked_list = []
+ list_len = len(DropList)
+    # later raindrops in the list carry more collision information,
+    # so reverse the list
+ DropList.reverse()
+ drop_key = 1
+ for drop in DropList:
+        # if the drop has not been handled yet
+ if drop.getKey() not in Checked_list:
+ # if drop has collision with other drops
+ if drop.getIfColli():
+ # get collision list
+ collision_list = drop.getCollisionList()
+            # first get radius and center to decide how the collided drops are merged
+ final_x = drop.getCenters()[0] * drop.getRadius()
+ final_y = drop.getCenters()[1] * drop.getRadius()
+ tmp_devide = drop.getRadius()
+ final_R = drop.getRadius() * drop.getRadius()
+ for col_id in collision_list:
+ col_id = int(col_id)
+ Checked_list.append(col_id)
+ # list start from 0
+ final_x += DropList[list_len - col_id].getRadius() * DropList[list_len - col_id].getCenters()[0]
+ final_y += DropList[list_len - col_id].getRadius() * DropList[list_len - col_id].getCenters()[1]
+ tmp_devide += DropList[list_len - col_id].getRadius()
+ final_R += DropList[list_len - col_id].getRadius() * DropList[list_len - col_id].getRadius()
+ final_x = int(round(final_x / tmp_devide))
+ final_y = int(round(final_y / tmp_devide))
+ final_R = int(round(math.sqrt(final_R)))
+            # rebuild the drop after handling the collisions
+ newDrop = Raindrop(drop_key, (final_x, final_y), final_R)
+ drop_key = drop_key + 1
+ listFinalDrops.append(newDrop)
+ # no collision
+ else:
+ drop.setKey(drop_key)
+ drop_key = drop_key + 1
+ listFinalDrops.append(drop)
+
+ return listFinalDrops
+
+
+def generate_label(h, w, cfg):
+ """This function generate list of raindrop class objects and label map of
+ this drops in the image.
+
+ :param h: image height
+ :param w: image width
+ :param cfg: config with global constants
+ :param shape: int from 0 to 2 defining raindrop shape type
+ """
+ maxDrop = cfg['maxDrops']
+ minDrop = cfg['minDrops']
+ maxR = cfg['maxR']
+ minR = cfg['minR']
+ drop_num = randint(minDrop, maxDrop)
+ imgh = h
+ imgw = w
+ # random drops position
+ ran_pos = [(int(random.random() * imgw), int(random.random() * imgh)) for _ in range(drop_num)]
+ listRainDrops = []
+ listFinalDrops = []
+ for key, pos in enumerate(ran_pos):
+ key = key + 1
+ radius = random.randint(minR, maxR)
+ shape = random.randint(1, 1)
+ drop = Raindrop(key, pos, radius, shape)
+ listRainDrops.append(drop)
+    # check whether any drops collide
+ label_map = np.zeros([h, w])
+ collisionNum = len(listRainDrops)
+ listFinalDrops = list(listRainDrops)
+ loop = 0
+ while collisionNum > 0:
+ loop = loop + 1
+ listFinalDrops = list(listFinalDrops)
+ collisionNum = len(listFinalDrops)
+ label_map = np.zeros_like(label_map)
+ # Check Collision
+ for drop in listFinalDrops:
+ # check the bounding
+ (ix, iy) = drop.getCenters()
+ radius = drop.getRadius()
+ ROI_WL = 2 * radius
+ ROI_WR = 2 * radius
+ ROI_HU = 3 * radius
+ ROI_HD = 2 * radius
+ if (iy - 3 * radius) < 0:
+ ROI_HU = iy
+ if (iy + 2 * radius) > imgh:
+ ROI_HD = imgh - iy
+ if (ix - 2 * radius) < 0:
+ ROI_WL = ix
+ if (ix + 2 * radius) > imgw:
+ ROI_WR = imgw - ix
+
+
+            # apply the raindrop's label map to the image's label map
+            drop_label = drop.getLabelMap()
+            # check if the center already has drops
+ if (label_map[iy, ix] > 0):
+ col_ids = np.unique(label_map[iy - ROI_HU:iy + ROI_HD, ix - ROI_WL:ix + ROI_WR])
+ col_ids = col_ids[col_ids != 0]
+ drop.setCollision(True, col_ids)
+ label_map[iy - ROI_HU:iy + ROI_HD,
+ ix - ROI_WL:ix + ROI_WR] = drop_label[3 * radius - ROI_HU:3 * radius + ROI_HD, 2 * radius -
+ ROI_WL:2 * radius + ROI_WR] * drop.getKey()
+ else:
+ label_map[iy - ROI_HU:iy + ROI_HD,
+ ix - ROI_WL:ix + ROI_WR] = drop_label[3 * radius - ROI_HU:3 * radius + ROI_HD, 2 * radius -
+ ROI_WL:2 * radius + ROI_WR] * drop.getKey()
+ # no collision
+ collisionNum = collisionNum - 1
+
+ if collisionNum > 0:
+ listFinalDrops = CheckCollision(listFinalDrops)
+ return listFinalDrops, label_map
+
+
+def generateDrops(imagePath, cfg, listFinalDrops):
+ """Generate raindrops on the image.
+
+ :param imagePath: path to the image on which you want to generate drops
+ :param cfg: config with global constants
+ :param listFinalDrops: final list of raindrop class objects after handling collisions
+ :param label_map: general label map of all drops in the image
+ """
+ ifReturnLabel = cfg['return_label']
+ edge_ratio = cfg['edge_darkratio']
+
+ PIL_bg_img = Image.open(imagePath).convert('RGB')
+ bg_img = np.asarray(PIL_bg_img)
+ label_map = np.zeros_like(bg_img)[:, :, 0]
+ imgh, imgw, _ = bg_img.shape
+
+ A = cfg['A']
+ B = cfg['B']
+ C = cfg['C']
+ D = cfg['D']
+
+ alpha_map = np.zeros_like(label_map).astype(np.float64)
+
+ for drop in listFinalDrops:
+ (ix, iy) = drop.getCenters()
+ radius = drop.getRadius()
+ ROI_WL = 2 * radius
+ ROI_WR = 2 * radius
+ ROI_HU = 3 * radius
+ ROI_HD = 2 * radius
+ if (iy - 3 * radius) < 0:
+ ROI_HU = iy
+ if (iy + 2 * radius) > imgh:
+ ROI_HD = imgh - iy
+ if (ix - 2 * radius) < 0:
+ ROI_WL = ix
+ if (ix + 2 * radius) > imgw:
+ ROI_WR = imgw - ix
+
+ drop_alpha = drop.getAlphaMap()
+ alpha_map[iy - ROI_HU:iy + ROI_HD,
+ ix - ROI_WL:ix + ROI_WR] += drop_alpha[3 * radius - ROI_HU:3 * radius + ROI_HD,
+ 2 * radius - ROI_WL:2 * radius + ROI_WR]
+
+ alpha_map = alpha_map / np.max(alpha_map) * 255.0
+
+ PIL_bg_img = Image.open(imagePath)
+ for idx, drop in enumerate(listFinalDrops):
+ (ix, iy) = drop.getCenters()
+ radius = drop.getRadius()
+ ROIU = iy - 3 * radius
+ ROID = iy + 2 * radius
+ ROIL = ix - 2 * radius
+ ROIR = ix + 2 * radius
+ if (iy - 3 * radius) < 0:
+ ROIU = 0
+ ROID = 5 * radius
+ if (iy + 2 * radius) > imgh:
+ ROIU = imgh - 5 * radius
+ ROID = imgh
+ if (ix - 2 * radius) < 0:
+ ROIL = 0
+ ROIR = 4 * radius
+ if (ix + 2 * radius) > imgw:
+ ROIL = imgw - 4 * radius
+ ROIR = imgw
+
+ tmp_bg = bg_img[ROIU:ROID, ROIL:ROIR, :]
+ try:
+ drop.updateTexture(tmp_bg)
+ except Exception:
+ del listFinalDrops[idx]
+ continue
+ tmp_alpha_map = alpha_map[ROIU:ROID, ROIL:ROIR]
+
+ output = drop.getTexture()
+        tmp_output = np.asarray(output).astype(np.float64)[:, :, -1]
+ tmp_alpha_map = tmp_alpha_map * (tmp_output / 255)
+ tmp_alpha_map = Image.fromarray(tmp_alpha_map.astype('uint8'))
+
+ edge = ImageEnhance.Brightness(output)
+ edge = edge.enhance(edge_ratio)
+
+ PIL_bg_img.paste(edge, (ix - 2 * radius, iy - 3 * radius), output)
+ PIL_bg_img.paste(output, (ix - 2 * radius, iy - 3 * radius), output)
+
+ mask = np.zeros_like(bg_img)
+
+ if len(listFinalDrops) > 0:
+        # make circles and ellipses
+ for drop in listFinalDrops:
+ if (drop.shape == 0):
+ cv2.circle(mask, drop.center, drop.radius, (255, 255, 255), -1)
+ if (drop.shape == 1):
+ cv2.circle(mask, drop.center, drop.radius, (255, 255, 255), -1)
+ cv2.ellipse(mask, drop.center, (drop.radius, int(1.3 * math.sqrt(3) * drop.radius)), 0, 180, 360,
+ (255, 255, 255), -1)
+
+ img = Image.fromarray(np.uint8(mask[:, :, 0]), 'L')
+ # make beziers
+ for drop in listFinalDrops:
+ if (drop.shape == 2):
+ img = Image.fromarray(np.uint8(img), 'L')
+ draw = ImageDraw.Draw(img)
+ ts = [t / 100.0 for t in range(101)]
+ xys = [(drop.radius * C[0] - 2 * drop.radius + drop.center[0],
+ drop.radius * C[1] - 3 * drop.radius + drop.center[1]),
+ (drop.radius * B[0] - 2 * drop.radius + drop.center[0],
+ drop.radius * B[1] - 3 * drop.radius + drop.center[1]),
+ (drop.radius * D[0] - 2 * drop.radius + drop.center[0],
+ drop.radius * D[1] - 3 * drop.radius + drop.center[1])]
+ bezier = make_bezier(xys)
+ points = bezier(ts)
+
+ xys = [(drop.radius * C[0] - 2 * drop.radius + drop.center[0],
+ drop.radius * C[1] - 3 * drop.radius + drop.center[1]),
+ (drop.radius * A[0] - 2 * drop.radius + drop.center[0],
+ drop.radius * A[1] - 3 * drop.radius + drop.center[1]),
+ (drop.radius * D[0] - 2 * drop.radius + drop.center[0],
+ drop.radius * D[1] - 3 * drop.radius + drop.center[1])]
+ bezier = make_bezier(xys)
+ points.extend(bezier(ts))
+ draw.polygon(points, fill='white')
+ mask = np.array(img)
+
+ im_mask = Image.fromarray(mask.astype('uint8'))
+
+ if ifReturnLabel:
+ output_label = np.array(alpha_map)
+ output_label.flags.writeable = True
+ output_label[output_label > 0] = 1
+ output_label = Image.fromarray(output_label.astype('uint8'))
+ return PIL_bg_img, output_label, im_mask
+
+ return PIL_bg_img
+
+
+def generateDrops_np(img_np, cfg, listFinalDrops):
+ """Generate raindrops on the image.
+
+ :param imgs: numpy imgs shape -> [B, H, W, C], type -> np.uint8
+ :param cfg: config with global constants
+ :param listFinalDrops: final list of raindrop class objects after handling collisions
+ :param label_map: general label map of all drops in the image
+ """
+ ifReturnLabel = cfg['return_label']
+ edge_ratio = cfg['edge_darkratio']
+
+ # PIL_bg_img = Image.open(imagePath)
+ # label_map = np.zeros_like(bg_img)[:,:,0]
+ # imgh, imgw, _ = bg_img.shape
+ bg_img = img_np
+ label_map = np.zeros_like(bg_img)[:, :, 0] # [H, W]
+ imgh, imgw, _ = bg_img.shape
+
+ A = cfg['A']
+ B = cfg['B']
+ C = cfg['C']
+ D = cfg['D']
+
+ # 0. generate alpha change map by generated list raindrops
+ alpha_map = np.zeros_like(label_map).astype(np.float64) # [H, W]
+
+ for drop in listFinalDrops:
+ (ix, iy) = drop.getCenters()
+ radius = drop.getRadius()
+ ROI_WL = 2 * radius
+ ROI_WR = 2 * radius
+ ROI_HU = 3 * radius
+ ROI_HD = 2 * radius
+ if (iy - 3 * radius) < 0:
+ ROI_HU = iy
+ if (iy + 2 * radius) > imgh:
+ ROI_HD = imgh - iy
+ if (ix - 2 * radius) < 0:
+ ROI_WL = ix
+ if (ix + 2 * radius) > imgw:
+ ROI_WR = imgw - ix
+
+ drop_alpha = drop.getAlphaMap()
+
+ alpha_map[iy - ROI_HU:iy + ROI_HD,
+ ix - ROI_WL:ix + ROI_WR] += drop_alpha[3 * radius - ROI_HU:3 * radius + ROI_HD,
+ 2 * radius - ROI_WL:2 * radius + ROI_WR]
+
+ alpha_map = alpha_map / np.max(alpha_map) * 255.0
+
+ PIL_bg_img = Image.fromarray(np.uint8(bg_img)).convert('RGB')
+ # convert
+ for idx, drop in enumerate(listFinalDrops):
+ (ix, iy) = drop.getCenters()
+ radius = drop.getRadius()
+ ROIU = iy - 3 * radius
+ ROID = iy + 2 * radius
+ ROIL = ix - 2 * radius
+ ROIR = ix + 2 * radius
+ if (iy - 3 * radius) < 0:
+ ROIU = 0
+ ROID = 5 * radius
+ if (iy + 2 * radius) > imgh:
+ ROIU = imgh - 5 * radius
+ ROID = imgh
+ if (ix - 2 * radius) < 0:
+ ROIL = 0
+ ROIR = 4 * radius
+ if (ix + 2 * radius) > imgw:
+ ROIL = imgw - 4 * radius
+ ROIR = imgw
+
+ tmp_bg = bg_img[ROIU:ROID, ROIL:ROIR]
+ try:
+ drop.updateTexture(tmp_bg)
+ except Exception:
+ del listFinalDrops[idx]
+ continue
+ tmp_alpha_map = alpha_map[ROIU:ROID, ROIL:ROIR]
+
+ output = drop.getTexture()
+        tmp_output = np.asarray(output).astype(np.float64)[:, :, -1]
+ tmp_alpha_map = tmp_alpha_map * (tmp_output / 255)
+ tmp_alpha_map = Image.fromarray(tmp_alpha_map.astype('uint8'))
+
+ edge = ImageEnhance.Brightness(output)
+ edge = edge.enhance(edge_ratio)
+
+ # PIL_bg_img.paste(edge, (ix-2*radius, iy-3*radius), output)
+ # PIL_bg_img.paste(output, (ix-2*radius, iy-3*radius), output)
+ PIL_bg_img.paste(edge, (ROIL, ROIU), output)
+ PIL_bg_img.paste(output, (ROIL, ROIU), output)
+
+
+    # mask processing part
+ mask = np.zeros_like(bg_img)
+
+ if len(listFinalDrops) > 0:
+        # make circles and ellipses
+ for drop in listFinalDrops:
+ if (drop.shape == 0):
+ cv2.circle(mask, drop.center, drop.radius, (255, 255, 255), -1)
+ if (drop.shape == 1):
+ cv2.circle(mask, drop.center, drop.radius, (255, 255, 255), -1)
+ cv2.ellipse(mask, drop.center, (drop.radius, int(1.3 * math.sqrt(3) * drop.radius)), 0, 180, 360,
+ (255, 255, 255), -1)
+
+ img = Image.fromarray(np.uint8(mask[:, :, 0]), 'L')
+ # make beziers
+ for drop in listFinalDrops:
+ if (drop.shape == 2):
+ img = Image.fromarray(np.uint8(img), 'L')
+ draw = ImageDraw.Draw(img)
+ ts = [t / 100.0 for t in range(101)]
+ A0, A1 = drop.control_point['A'][0], drop.control_point['A'][1]
+ B0, B1 = drop.control_point['B'][0], drop.control_point['B'][1]
+ C0, C1 = drop.control_point['C'][0], drop.control_point['C'][1]
+ D0, D1 = drop.control_point['D'][0], drop.control_point['D'][1]
+ xys = [(drop.radius * C0 - 2 * drop.radius + drop.center[0],
+ drop.radius * C1 - 3 * drop.radius + drop.center[1]),
+ (drop.radius * B0 - 2 * drop.radius + drop.center[0],
+ drop.radius * B1 - 3 * drop.radius + drop.center[1]),
+ (drop.radius * D0 - 2 * drop.radius + drop.center[0],
+ drop.radius * D1 - 3 * drop.radius + drop.center[1])]
+ bezier = make_bezier(xys)
+ points = bezier(ts)
+
+ xys = [(drop.radius * C0 - 2 * drop.radius + drop.center[0],
+ drop.radius * C1 - 3 * drop.radius + drop.center[1]),
+ (drop.radius * A0 - 2 * drop.radius + drop.center[0],
+ drop.radius * A1 - 3 * drop.radius + drop.center[1]),
+ (drop.radius * D0 - 2 * drop.radius + drop.center[0],
+ drop.radius * D1 - 3 * drop.radius + drop.center[1])]
+ bezier = make_bezier(xys)
+ points.extend(bezier(ts))
+ draw.polygon(points, fill='white')
+ mask = np.array(img)
+
+ im_mask = Image.fromarray(mask.astype('uint8'))
+
+ if ifReturnLabel:
+ output_label = np.array(alpha_map)
+ output_label.flags.writeable = True
+ output_label[output_label > 0] = 1
+ output_label = Image.fromarray(output_label.astype('uint8'))
+ return PIL_bg_img, output_label, im_mask
+
+ return PIL_bg_img
diff --git a/training/mono/utils/raindropper/raindrop.py b/training/mono/utils/raindropper/raindrop.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae7b66c1f5f1a6280b8898b06206131c8a6e289f
--- /dev/null
+++ b/training/mono/utils/raindropper/raindrop.py
@@ -0,0 +1,194 @@
+# change rainy drop func from
+# https://github.com/EvoCargo/RaindropsOnWindshield/blob/main/raindrops_generator/raindrop/raindrop.py
+
+import math
+import random
+from random import randint
+
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw, ImageFilter
+from .config import cfg
+
+
+def make_bezier(xys):
+ # xys should be a sequence of 2-tuples (Bezier control points)
+ n = len(xys)
+ combinations = pascal_row(n - 1)
+
+ def bezier(ts):
+ # This uses the generalized formula for bezier curves
+ # http://en.wikipedia.org/wiki/B%C3%A9zier_curve#Generalization
+ result = []
+ for t in ts:
+ tpowers = (t**i for i in range(n))
+ upowers = reversed([(1 - t)**i for i in range(n)])
+ coefs = [c * a * b for c, a, b in zip(combinations, tpowers, upowers)]
+ result.append(tuple(sum([coef * p for coef, p in zip(coefs, ps)]) for ps in zip(*xys)))
+ return result
+
+ return bezier
+
+
+def pascal_row(n, memo={}):
+ # This returns the nth row of Pascal Triangle
+ if n in memo:
+ return memo[n]
+ result = [1]
+ x, numerator = 1, n
+ for denominator in range(1, n // 2 + 1):
+ x *= numerator
+ x /= denominator
+ result.append(x)
+ numerator -= 1
+ if n & 1 == 0:
+ result.extend(reversed(result[:-1]))
+ else:
+ result.extend(reversed(result))
+ memo[n] = result
+ return result
+
+
+class Raindrop():
+
+ def __init__(self, key, centerxy=None, radius=None, shape=None):
+ # param key: a unique key identifying a drop
+ # param centerxy: tuple defining coordinates of raindrop center in the image
+ # param radius: radius of a drop
+ # param shape: int from 0 to 2 defining raindrop shape type
+ self.key = key
+ self.ifcol = False
+ self.col_with = []
+ self.center = centerxy
+ self.radius = radius
+ # self.blur_coeff = max(int(self.radius/3), 1)
+ # self.blur_coeff = max(int(cfg["maxR"] / self.radius), 1)
+ self.blur_coeff = 3
+ self.shape = shape
+ self.type = 'default'
+ # label map's WxH = 4*R , 5*R
+ self.labelmap = np.zeros((self.radius * 5, self.radius * 4))
+ self.alphamap = np.zeros((self.radius * 5, self.radius * 4))
+ self.background = None
+ self.texture = None
+ self.control_point = {}
+ self._create_label()
+ self.use_label = False
+
+ def setCollision(self, col, col_with):
+ self.ifcol = col
+ self.col_with = col_with
+
+ def updateTexture(self, bg):
+        # gaussian blur radius is randomly chosen from [1, 3]
+ radius_array = [1, 3]
+ blur_radius_idx = randint(0, 1)
+ blur_radius = radius_array[blur_radius_idx]
+ fg = (Image.fromarray(np.uint8(bg))).filter(ImageFilter.GaussianBlur(radius=blur_radius))
+ fg = np.asarray(fg)
+
+        # add a fisheye effect to simulate the background seen through the drop
+ K = np.array([[30 * self.radius, 0, 2 * self.radius], [0., 20 * self.radius, 3 * self.radius], [0., 0., 1]])
+ D = np.array([0.0, 0.0, 0.0, 0.0])
+ Knew = K.copy()
+
+ Knew[(0, 1), (0, 1)] = math.pow(self.radius, 1 / 500) * 2 * Knew[(0, 1), (0, 1)]
+ fisheye = cv2.fisheye.undistortImage(fg, K, D=D, Knew=Knew)
+
+ tmp = np.expand_dims(self.alphamap, axis=-1)
+ tmp = np.concatenate((fisheye, tmp), axis=2)
+
+ self.texture = Image.fromarray(tmp.astype('uint8'), 'RGBA')
+
+ def _create_label(self):
+ self._createDefaultDrop()
+
+ def _createDefaultDrop(self):
+ """create the raindrop Alpha Map according to its shape type update
+ raindrop label."""
+ if (self.shape == 0):
+ cv2.circle(self.labelmap, (self.radius * 2, self.radius * 3), int(self.radius), 128, -1)
+ self.alphamap = (Image.fromarray(np.uint8(self.labelmap))).filter(
+ ImageFilter.GaussianBlur(radius=self.blur_coeff))
+            self.alphamap = np.asarray(self.alphamap).astype(np.float64)
+ self.alphamap = self.alphamap / np.max(self.alphamap) * 255.0
+ # set label map
+ self.labelmap[self.labelmap > 0] = 1
+
+ if (self.shape == 1):
+ cv2.circle(self.labelmap, (self.radius * 2, self.radius * 3), int(self.radius), 128, -1)
+ cv2.ellipse(self.labelmap, (self.radius * 2, self.radius * 3),
+ (self.radius, int(1.3 * math.sqrt(3) * self.radius)), 0, 180, 360, 128, -1)
+
+ self.alphamap = (Image.fromarray(np.uint8(self.labelmap))).filter(
+ ImageFilter.GaussianBlur(radius=self.blur_coeff))
+            self.alphamap = np.asarray(self.alphamap).astype(np.float64)
+ self.alphamap = self.alphamap / np.max(self.alphamap) * 255.0
+ # set label map
+ self.labelmap[self.labelmap > 0] = 1
+
+ if (self.shape == 2):
+ C0 = random.uniform(0, 1)
+ C1 = random.uniform(0, 1)
+ A0 = random.uniform(0, 1)
+ A1 = random.uniform(2, 3)
+ D0 = random.uniform(2, 3)
+ D1 = random.uniform(2, 3)
+ B0 = random.uniform(2, 3)
+ B1 = random.uniform(0, 1)
+
+ self.control_point['A'] = (A0, A1)
+ self.control_point['B'] = (B0, B1)
+ self.control_point['C'] = (C0, C1)
+ self.control_point['D'] = (D0, D1)
+
+ img = Image.fromarray(np.uint8(self.labelmap), 'L')
+ draw = ImageDraw.Draw(img)
+ ts = [t / 100.0 for t in range(101)]
+ xys = [(self.radius * C0, self.radius * C1), (self.radius * B0, self.radius * B1),
+ (self.radius * D0, self.radius * D1)]
+ bezier = make_bezier(xys)
+ points = bezier(ts)
+
+ xys = [(self.radius * C0, self.radius * C1), (self.radius * A0, self.radius * A1),
+ (self.radius * D0, self.radius * D1)]
+ bezier = make_bezier(xys)
+ points.extend(bezier(ts))
+ draw.polygon(points, fill='gray')
+
+ self.alphamap = img.filter(ImageFilter.GaussianBlur(radius=self.blur_coeff))
+            self.alphamap = np.asarray(self.alphamap).astype(np.float64)
+ self.alphamap = self.alphamap / np.max(self.alphamap) * 255.0
+
+ # set label map
+ self.labelmap[self.labelmap > 0] = 1
+
+ def setKey(self, key):
+ self.key = key
+
+ def getLabelMap(self):
+ return self.labelmap
+
+ def getAlphaMap(self):
+ return self.alphamap
+
+ def getTexture(self):
+ return self.texture
+
+ def getCenters(self):
+ return self.center
+
+ def getRadius(self):
+ return self.radius
+
+ def getKey(self):
+ return self.key
+
+ def getIfColli(self):
+ return self.ifcol
+
+ def getCollisionList(self):
+ return self.col_with
+
+ def getUseLabel(self):
+ return self.use_label
diff --git a/training/mono/utils/raindropper/raindrop_augmentor.py b/training/mono/utils/raindropper/raindrop_augmentor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c160359cee94128cb644ee886231d3d45367fdb0
--- /dev/null
+++ b/training/mono/utils/raindropper/raindrop_augmentor.py
@@ -0,0 +1,30 @@
+import numpy as np
+
+from .config import cfg
+from .dropgenerator import generate_label, generateDrops_np
+
+
+class RainDrop_Augmentor():
+
+ def __init__(self, height, width):
+ drops_list, label_map = generate_label(height, width, cfg)
+ self.drops_list = drops_list
+ self.label_map = label_map
+
+ def generate_one(self, img_np, mode='rgb'):
+
+ assert mode in ['gray', 'rgb']
+
+ # requirement input [H, W, 3]
+ if (mode == 'gray'):
+ img_np = np.squeeze(img_np)
+ img_np = np.expand_dims(img_np, axis=-1)
+ img_np = np.repeat(img_np, 3, axis=-1)
+
+ output_img, output_label, mask = generateDrops_np(img_np, cfg, self.drops_list)
+ output_img = np.asarray(output_img)
+
+ if (mode == 'gray'):
+ output_img = output_img[:, :, 0]
+
+ return output_img
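+
+
+# Minimal usage sketch (illustrative only): augment a random RGB image with the
+# raindrop pattern generated in the constructor. The image size is arbitrary.
+if __name__ == '__main__':
+    h, w = 480, 640
+    augmentor = RainDrop_Augmentor(h, w)
+    dummy = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
+    rainy = augmentor.generate_one(dummy, mode='rgb')
+    print(rainy.shape)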
diff --git a/training/mono/utils/running.py b/training/mono/utils/running.py
new file mode 100644
index 0000000000000000000000000000000000000000..df0efc0b29b8c72411fd27bb374c2782f513c51f
--- /dev/null
+++ b/training/mono/utils/running.py
@@ -0,0 +1,374 @@
+import os
+import torch
+import torch.nn as nn
+from mono.utils.comm import main_process
+import copy
+import inspect
+import logging
+import glob
+
+class LrUpdater():
+ """Refer to LR Scheduler in MMCV.
+ Args:
+ @by_epoch (bool): LR changes epoch by epoch
+ @warmup (string): Type of warmup used. It can be None(use no warmup),
+ 'constant', 'linear' or 'exp'
+ @warmup_iters (int): The number of iterations or epochs that warmup
+ lasts. Note when by_epoch == True, warmup_iters means the number
+ of epochs that warmup lasts, otherwise means the number of
+            iterations that warmup lasts
+ @warmup_ratio (float): LR used at the beginning of warmup equals to
+ warmup_ratio * initial_lr
+        @runner (dict): Configs for running. Run by epochs or iters.
+ """
+
+ def __init__(self,
+ by_epoch: bool=True,
+ warmup: str=None,
+ warmup_iters: int=0,
+ warmup_ratio: float=0.1,
+ runner: dict={}):
+ # validate the "warmup" argument
+ if warmup is not None:
+ if warmup not in ['constant', 'linear', 'exp']:
+ raise ValueError(
+ f'"{warmup}" is not a supported type for warming up, valid'
+                    ' types are "constant", "linear" and "exp"')
+ if warmup is not None:
+ assert warmup_iters > 0, \
+ '"warmup_iters" must be a positive integer'
+ assert 0 < warmup_ratio <= 1.0, \
+ '"warmup_ratio" must be in range (0,1]'
+
+ if runner is None:
+ raise RuntimeError('runner should be set.')
+
+ self.by_epoch = by_epoch
+ self.warmup = warmup
+ self.warmup_iters = warmup_iters
+ self.warmup_ratio = warmup_ratio
+ self.runner = runner
+
+ self.max_iters = None
+ self.max_epoches = None
+ if 'IterBasedRunner' in self.runner.type:
+ self.max_iters = self.runner.max_iters
+ assert self.by_epoch==False
+ self.warmup_by_epoch = False
+ elif 'EpochBasedRunner' in self.runner.type:
+ self.max_epoches = self.runner.max_epoches
+ assert self.by_epoch==True
+ self.warmup_by_epoch = True
+ else:
+ raise ValueError(f'{self.runner.type} is not a supported type for running.')
+
+ if self.warmup_by_epoch:
+ self.warmup_epochs = self.warmup_iters
+ self.warmup_iters = None
+ else:
+ self.warmup_epochs = None
+
+ self.base_lr = [] # initial lr for all param groups
+ self.regular_lr = [] # expected lr if no warming up is performed
+ self._step_count = 0
+
+ def _set_lr(self, optimizer, lr_groups):
+ if isinstance(optimizer, dict):
+ for k, optim in optimizer.items():
+ for param_group, lr in zip(optim.param_groups, lr_groups[k]):
+ param_group['lr'] = lr
+ else:
+ for param_group, lr in zip(optimizer.param_groups,
+ lr_groups):
+ param_group['lr'] = lr
+
+ def get_lr(self, _iter, max_iter, base_lr):
+ raise NotImplementedError
+
+ def get_regular_lr(self, _iter, optimizer):
+ max_iters = self.max_iters if not self.by_epoch else self.max_epoches
+
+ if isinstance(optimizer, dict):
+ lr_groups = {}
+ for k in optimizer.keys():
+ _lr_group = [
+ self.get_lr(_iter, max_iters, _base_lr)
+ for _base_lr in self.base_lr[k]
+ ]
+ lr_groups.update({k: _lr_group})
+
+ return lr_groups
+ else:
+ return [self.get_lr(_iter, max_iters, _base_lr) for _base_lr in self.base_lr]
+
+ def get_warmup_lr(self, cur_iters):
+
+ def _get_warmup_lr(cur_iters, regular_lr):
+ if self.warmup == 'constant':
+ warmup_lr = [_lr * self.warmup_ratio for _lr in regular_lr]
+ elif self.warmup == 'linear':
+ k = (1 - cur_iters / self.warmup_iters) * (1 -
+ self.warmup_ratio)
+ warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
+ elif self.warmup == 'exp':
+ k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters)
+ warmup_lr = [_lr * k for _lr in regular_lr]
+ return warmup_lr
+
+ if isinstance(self.regular_lr, dict):
+ lr_groups = {}
+ for key, regular_lr in self.regular_lr.items():
+ lr_groups[key] = _get_warmup_lr(cur_iters, regular_lr)
+ return lr_groups
+ else:
+ return _get_warmup_lr(cur_iters, self.regular_lr)
+
+ def before_run(self, optimizer):
+ # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
+ # it will be set according to the optimizer params
+ if isinstance(optimizer, dict):
+ self.base_lr = {}
+ for k, optim in optimizer.items():
+ for group in optim.param_groups:
+ group.setdefault('initial_lr', group['lr'])
+ _base_lr = [
+ group['initial_lr'] for group in optim.param_groups
+ ]
+ self.base_lr.update({k: _base_lr})
+ else:
+ for group in optimizer.param_groups:
+ group.setdefault('initial_lr', group['lr'])
+ self.base_lr = [
+ group['initial_lr'] for group in optimizer.param_groups
+ ]
+
+ def after_train_epoch(self, optimizer):
+ self._step_count += 1
+ curr_epoch = self._step_count
+ self.regular_lr = self.get_regular_lr(curr_epoch, optimizer)
+        if self.warmup is None or curr_epoch > self.warmup_epochs:
+ self._set_lr(optimizer, self.regular_lr)
+ else:
+ #self.warmup_iters = int(self.warmup_epochs * epoch_len)
+ warmup_lr = self.get_warmup_lr(curr_epoch)
+ self._set_lr(optimizer, warmup_lr)
+
+ def after_train_iter(self, optimizer):
+ self._step_count += 1
+ cur_iter = self._step_count
+ self.regular_lr = self.get_regular_lr(cur_iter, optimizer)
+ if self.warmup is None or cur_iter >= self.warmup_iters:
+ self._set_lr(optimizer, self.regular_lr)
+ else:
+ warmup_lr = self.get_warmup_lr(cur_iter)
+ self._set_lr(optimizer, warmup_lr)
+
+ def get_curr_lr(self, cur_iter):
+ if self.warmup is None or cur_iter >= self.warmup_iters:
+ return self.regular_lr
+ else:
+ return self.get_warmup_lr(cur_iter)
+
+ def state_dict(self):
+ """
+ Returns the state of the scheduler as a :class:`dict`.
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer.
+ """
+ return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
+
+ def load_state_dict(self, state_dict):
+ """Loads the schedulers state.
+
+ Args:
+ @state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ self.__dict__.update(state_dict)
+
+
+class PolyLrUpdater(LrUpdater):
+
+ def __init__(self, power=1., min_lr=0., **kwargs):
+ self.power = power
+ self.min_lr = min_lr
+ super(PolyLrUpdater, self).__init__(**kwargs)
+
+ def get_lr(self, _iter, max_iters, base_lr):
+ progress = _iter
+ max_progress = max_iters
+ coeff = (1 - progress / max_progress)**self.power
+ return (base_lr - self.min_lr) * coeff + self.min_lr
+
+
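+# Worked example (illustrative only): with base_lr=1e-3, power=0.9 and 20000
+# total iterations, the poly schedule gives lr = 1e-3 * (1 - 5000/20000)**0.9
+# ~= 7.7e-4 at iteration 5000. SimpleNamespace stands in for the mmcv-style
+# runner config used by the real pipeline; that substitution is an assumption.
+def _demo_poly_lr():
+    from types import SimpleNamespace
+    runner = SimpleNamespace(type='IterBasedRunner', max_iters=20000)
+    schedule = PolyLrUpdater(power=0.9, by_epoch=False, runner=runner)
+    return schedule.get_lr(5000, runner.max_iters, 1e-3)
+
+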
+def build_lr_schedule_with_cfg(cfg):
+ # build learning rate schedule with config.
+ lr_config = copy.deepcopy(cfg.lr_config)
+ policy = lr_config.pop('policy')
+ if cfg.lr_config.policy == 'poly':
+ schedule = PolyLrUpdater(runner=cfg.runner, **lr_config)
+ else:
+ raise RuntimeError(f'{cfg.lr_config.policy} \
+ is not supported in this framework.')
+ return schedule
+
+
+#def step_learning_rate(base_lr, epoch, step_epoch, multiplier=0.1):
+# """Sets the learning rate to the base LR decayed by 10 every step epochs"""
+# lr = base_lr * (multiplier ** (epoch // step_epoch))
+# return lr
+
+def register_torch_optimizers():
+ torch_optimizers = {}
+ for module_name in dir(torch.optim):
+ if module_name.startswith('__'):
+ continue
+ _optim = getattr(torch.optim, module_name)
+ if inspect.isclass(_optim) and issubclass(_optim,
+ torch.optim.Optimizer):
+ torch_optimizers[module_name] = _optim
+ return torch_optimizers
+
+
+TORCH_OPTIMIZER = register_torch_optimizers()
+
+def build_optimizer_with_cfg(cfg, model):
+ # encoder_parameters = []
+ # decoder_parameters = []
+ # nongrad_parameters = []
+ # for key, value in dict(model.named_parameters()).items():
+ # if value.requires_grad:
+ # if 'encoder' in key:
+ # encoder_parameters.append(value)
+ # else:
+ # decoder_parameters.append(value)
+ # else:
+ # nongrad_parameters.append(value)
+
+ #params = [{"params": filter(lambda p: p.requires_grad, model.parameters())}]
+ optim_cfg = copy.deepcopy(cfg.optimizer)
+ optim_type = optim_cfg.pop('type', None)
+
+ if optim_type is None:
+ raise RuntimeError(f'{optim_type} is not set')
+ if optim_type not in TORCH_OPTIMIZER:
+ raise RuntimeError(f'{optim_type} is not supported in torch {torch.__version__}')
+ if 'others' not in optim_cfg:
+ optim_cfg['others'] = optim_cfg['decoder']
+
+ def match(key1, key_list, strict_match=False):
+ if not strict_match:
+ for k in key_list:
+ if k in key1:
+ return k
+ else:
+ for k in key_list:
+ if k == key1.split('.')[1]:
+ return k
+ return None
+ optim_obj = TORCH_OPTIMIZER[optim_type]
+ matching_type = optim_cfg.pop('strict_match', False)
+
+ module_names = optim_cfg.keys()
+ model_parameters = {i: [] for i in module_names}
+ model_parameters['others'] = []
+ nongrad_parameters = []
+ for key, value in dict(model.named_parameters()).items():
+ if value.requires_grad:
+ match_key = match(key, module_names, matching_type)
+ # if optim_cfg[match_key]['lr'] == 0:
+ # value.requires_grad=False
+ # continue
+ if match_key is None:
+ model_parameters['others'].append(value)
+ else:
+ model_parameters[match_key].append(value)
+ else:
+ nongrad_parameters.append(value)
+
+ optims = [{'params':model_parameters[k], **optim_cfg[k]} for k in optim_cfg.keys()]
+ optimizer = optim_obj(optims)
+ # optim_args_encoder = optim_cfg.optimizer.encoder
+ # optim_args_decoder = optim_cfg.optimizer.decoder
+ # optimizer = optim_obj(
+ # [{'params':encoder_parameters, **optim_args_encoder},
+ # {'params':decoder_parameters, **optim_args_decoder},
+ # ])
+
+ return optimizer
+
+
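+# Illustrative sketch (not part of the pipeline): build an optimizer with two
+# parameter groups matched by the 'encoder'/'decoder' prefixes; unmatched
+# parameters fall back to 'others', which defaults to the 'decoder' settings.
+# SimpleNamespace and the tiny model below are assumptions for the example.
+def _demo_build_optimizer():
+    from types import SimpleNamespace
+    model = torch.nn.ModuleDict({'encoder': torch.nn.Linear(4, 4),
+                                 'decoder': torch.nn.Linear(4, 1)})
+    cfg = SimpleNamespace(optimizer=dict(type='AdamW',
+                                         encoder=dict(lr=1e-5, weight_decay=0.01),
+                                         decoder=dict(lr=1e-4, weight_decay=0.01)))
+    return build_optimizer_with_cfg(cfg, model)
+
+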
+def load_ckpt(load_path, model, optimizer=None, scheduler=None, strict_match=True, loss_scaler=None):
+ """
+ Load the check point for resuming training or finetuning.
+ """
+ logger = logging.getLogger()
+ if os.path.isfile(load_path):
+ if main_process():
+ logger.info(f"Loading weight '{load_path}'")
+ checkpoint = torch.load(load_path, map_location="cpu")
+ ckpt_state_dict = checkpoint['model_state_dict']
+ model.module.load_state_dict(ckpt_state_dict, strict=strict_match)
+
+ if optimizer is not None:
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ if scheduler is not None:
+ scheduler.load_state_dict(checkpoint['scheduler'])
+ if loss_scaler is not None and 'scaler' in checkpoint:
+ loss_scaler.load_state_dict(checkpoint['scaler'])
+ print('Loss scaler loaded', loss_scaler)
+ del ckpt_state_dict
+ del checkpoint
+ if main_process():
+ logger.info(f"Successfully loaded weight: '{load_path}'")
+ if scheduler is not None and optimizer is not None:
+ logger.info(f"Resume training from: '{load_path}'")
+ else:
+ if main_process():
+ raise RuntimeError(f"No weight found at '{load_path}'")
+ return model, optimizer, scheduler, loss_scaler
+
+
+def save_ckpt(cfg, model, optimizer, scheduler, curr_iter=0, curr_epoch=None, loss_scaler=None):
+ """
+ Save the model, optimizer, lr scheduler.
+ """
+ logger = logging.getLogger()
+
+ if 'IterBasedRunner' in cfg.runner.type:
+ max_iters = cfg.runner.max_iters
+ elif 'EpochBasedRunner' in cfg.runner.type:
+ max_iters = cfg.runner.max_epoches
+ else:
+ raise TypeError(f'{cfg.runner.type} is not supported')
+
+ ckpt = dict(model_state_dict=model.module.state_dict(),
+ optimizer=optimizer.state_dict(),
+                max_iter=max_iters,
+ scheduler=scheduler.state_dict(),
+ # current_iter=curr_iter,
+ # current_epoch=curr_epoch,
+ )
+ if loss_scaler is not None:
+ # amp state_dict
+ ckpt.update(dict(scaler=loss_scaler.state_dict()))
+
+ ckpt_dir = os.path.join(cfg.work_dir, 'ckpt')
+ os.makedirs(ckpt_dir, exist_ok=True)
+
+ save_name = os.path.join(ckpt_dir, 'step%08d.pth' % curr_iter)
+ saved_ckpts = glob.glob(ckpt_dir + '/step*.pth')
+ torch.save(ckpt, save_name)
+
+ # keep the last 8 ckpts
+ if len(saved_ckpts) > 8:
+ saved_ckpts.sort()
+ os.remove(saved_ckpts.pop(0))
+
+ logger.info(f'Save model: {save_name}')
+
+
+
+if __name__ == '__main__':
+ print(TORCH_OPTIMIZER)
\ No newline at end of file
diff --git a/training/mono/utils/transform.py b/training/mono/utils/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..40f1956fb5505159367f8927e5f7d044c10f42d1
--- /dev/null
+++ b/training/mono/utils/transform.py
@@ -0,0 +1,1491 @@
+#import collections
+import collections.abc as collections
+import cv2
+import math
+import numpy as np
+import numbers
+import random
+import torch
+from imgaug import augmenters as iaa
+import matplotlib
+import matplotlib.cm
+import mono.utils.weather_aug_utils as wa
+
+"""
+Provides a set of Pytorch transforms that use OpenCV instead of PIL (Pytorch default)
+for image manipulation.
+"""
+
+class Compose(object):
+ # Composes transforms: transforms.Compose([transforms.RandScale([0.5, 2.0]), transforms.ToTensor()])
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ for t in self.transforms:
+ images, labels, intrinsics, cam_models, normals, other_labels, transform_paras = t(images, labels, intrinsics, cam_models, normals, other_labels, transform_paras)
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+class ToTensor(object):
+ # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
+ def __init__(self, **kwargs):
+ return
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+        if not isinstance(images, list) or not isinstance(labels, list) or not isinstance(intrinsics, list):
+            raise (RuntimeError("transform.ToTensor() only handles lists of inputs/labels/intrinsics."))
+        if len(images) != len(intrinsics):
+            raise (RuntimeError("Numbers of images and intrinsics do not match."))
+        if not isinstance(images[0], np.ndarray) or not isinstance(labels[0], np.ndarray):
+            raise (RuntimeError("transform.ToTensor() only handles np.ndarray for the input and label "
+                                "[e.g. data read by cv2.imread()].\n"))
+        if not isinstance(intrinsics[0], list):
+            raise (RuntimeError("transform.ToTensor() only handles a list for the camera intrinsics"))
+
+        if len(images[0].shape) > 3 or len(images[0].shape) < 2:
+            raise (RuntimeError("transform.ToTensor() only handles images (np.ndarray) with 2 or 3 dims.\n"))
+        if len(labels[0].shape) > 3 or len(labels[0].shape) < 2:
+            raise (RuntimeError("transform.ToTensor() only handles labels (np.ndarray) with 2 or 3 dims.\n"))
+
+        if len(intrinsics[0]) > 4 or len(intrinsics[0]) < 3:
+            raise (RuntimeError("transform.ToTensor() only handles intrinsics (list) with 3 or 4 entries.\n"))
+
+ for i, img in enumerate(images):
+ if len(img.shape) == 2:
+ img = np.expand_dims(img, axis=2)
+ images[i] = torch.from_numpy(img.transpose((2, 0, 1))).float()
+ for i, lab in enumerate(labels):
+ if len(lab.shape) == 2:
+ lab = np.expand_dims(lab, axis=0)
+ labels[i] = torch.from_numpy(lab).float()
+ for i, intrinsic in enumerate(intrinsics):
+ if len(intrinsic) == 3:
+ intrinsic = [intrinsic[0],] + intrinsic
+ intrinsics[i] = torch.tensor(intrinsic, dtype=torch.float)
+ if cam_models is not None:
+ for i, cam_model in enumerate(cam_models):
+ cam_models[i] = torch.from_numpy(cam_model.transpose((2, 0, 1))).float() if cam_model is not None else None
+ if normals is not None:
+ for i, normal in enumerate(normals):
+ normals[i] = torch.from_numpy(normal.transpose((2, 0, 1))).float()
+ if other_labels is not None:
+ for i, lab in enumerate(other_labels):
+ if len(lab.shape) == 2:
+ lab = np.expand_dims(lab, axis=0)
+ other_labels[i] = torch.from_numpy(lab).float()
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+class Normalize(object):
+ # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std
+ def __init__(self, mean, std=None, **kwargs):
+ if std is None:
+ assert len(mean) > 0
+ else:
+ assert len(mean) == len(std)
+ self.mean = torch.tensor(mean).float()[:, None, None]
+ self.std = torch.tensor(std).float()[:, None, None] if std is not None \
+ else torch.tensor([1.0, 1.0, 1.0]).float()[:, None, None]
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ # if self.std is None:
+ # # for t, m in zip(image, self.mean):
+ # # t.sub(m)
+ # image = image - self.mean
+ # if ref_images is not None:
+ # for i, ref_i in enumerate(ref_images):
+ # ref_images[i] = ref_i - self.mean
+ # else:
+ # # for t, m, s in zip(image, self.mean, self.std):
+ # # t.sub(m).div(s)
+ # image = (image - self.mean) / self.std
+ # if ref_images is not None:
+ # for i, ref_i in enumerate(ref_images):
+ # ref_images[i] = (ref_i - self.mean) / self.std
+ for i, img in enumerate(images):
+ img = torch.div((img - self.mean), self.std)
+ images[i] = img
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+class ResizeCanonical(object):
+ """
+    Resize the input to the canonical space first, then resize it by a randomly sampled ratio.
+    In the first stage, we assume the distance holds while the camera model varies.
+    In the second stage, we aim to simulate observations at different distances, i.e. the camera moves along the optical axis.
+ Args:
+ images: list of RGB images.
+ labels: list of depth/disparity labels.
+ other labels: other labels, such as instance segmentations, semantic segmentations...
+ """
+ def __init__(self, **kwargs):
+ self.ratio_range = kwargs['ratio_range']
+ self.canonical_focal = kwargs['focal_length']
+ self.crop_size = kwargs['crop_size']
+
+ def random_on_canonical_transform(self, image, label, intrinsic, cam_model, to_random_ratio):
+ ori_h, ori_w, _ = image.shape
+ ori_focal = (intrinsic[0] + intrinsic[1]) / 2.0
+
+ to_canonical_ratio = self.canonical_focal / ori_focal
+ to_scale_ratio = to_random_ratio
+ resize_ratio = to_canonical_ratio * to_random_ratio
+ reshape_h = int(ori_h * resize_ratio + 0.5)
+ reshape_w = int(ori_w * resize_ratio + 0.5)
+
+ image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ if intrinsic is not None:
+ intrinsic = [self.canonical_focal, self.canonical_focal, intrinsic[2]*resize_ratio, intrinsic[3]*resize_ratio]
+ if label is not None:
+ # number of other labels may be less than that of image
+ label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ # scale the label and camera intrinsics
+ label = label / to_scale_ratio
+
+ if cam_model is not None:
+ # Should not directly resize the cam_model.
+ # Camera model should be resized in 'to canonical' stage, while it holds in 'random resizing' stage.
+ # cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
+
+ return image, label, intrinsic, cam_model, to_scale_ratio
+
+ def random_on_crop_transform(self, image, label, intrinsic, cam_model, to_random_ratio):
+ ori_h, ori_w, _ = image.shape
+ crop_h, crop_w = self.crop_size
+ ori_focal = (intrinsic[0] + intrinsic[1]) / 2.0
+
+ to_canonical_ratio = self.canonical_focal / ori_focal
+
+ # random resize based on the last crop size
+ proposal_reshape_h = int(crop_h * to_random_ratio + 0.5)
+ proposal_reshape_w = int(crop_w * to_random_ratio + 0.5)
+ resize_ratio_h = proposal_reshape_h / ori_h
+ resize_ratio_w = proposal_reshape_w / ori_w
+ resize_ratio = min(resize_ratio_h, resize_ratio_w) # resize based on the long edge
+ reshape_h = int(ori_h * resize_ratio + 0.5)
+ reshape_w = int(ori_w * resize_ratio + 0.5)
+
+ to_scale_ratio = resize_ratio / to_canonical_ratio
+
+ image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ if intrinsic is not None:
+ intrinsic = [self.canonical_focal, self.canonical_focal, intrinsic[2]*resize_ratio, intrinsic[3]*resize_ratio]
+ if label is not None:
+ # number of other labels may be less than that of image
+ label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ # scale the label and camera intrinsics
+ label = label / to_scale_ratio
+
+ if cam_model is not None:
+ # Should not directly resize the cam_model.
+ # Camera model should be resized in 'to canonical' stage, while it holds in 'random resizing' stage.
+ # cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
+ return image, label, intrinsic, cam_model, to_scale_ratio
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ assert len(images[0].shape) == 3 and len(labels[0].shape) == 2
+        assert labels[0].dtype == np.float64
+ target_focal = (intrinsics[0][0] + intrinsics[0][1]) / 2.0
+ target_to_canonical_ratio = self.canonical_focal / target_focal
+ target_img_shape = images[0].shape
+ to_random_ratio = random.uniform(self.ratio_range[0], self.ratio_range[1])
+ to_scale_ratio = 0.0
+ for i in range(len(images)):
+ img = images[i]
+ label = labels[i] if i < len(labels) else None
+ intrinsic = intrinsics[i] if i < len(intrinsics) else None
+ cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+ img, label, intrinsic, cam_model, to_scale_ratio = self.random_on_canonical_transform(
+ img, label, intrinsic, cam_model, to_random_ratio)
+
+ images[i] = img
+ if label is not None:
+ labels[i] = label
+ if intrinsic is not None:
+ intrinsics[i] = intrinsic
+ if cam_model is not None:
+ cam_models[i] = cam_model
+
+        if normals is not None:
+ reshape_h, reshape_w, _ = images[0].shape
+ for i, normal in enumerate(normals):
+ normals[i] = cv2.resize(normal, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+
+        if other_labels is not None:
+ # other labels are like semantic segmentations, instance segmentations, instance planes segmentations...
+ #resize_ratio = target_to_canonical_ratio * to_scale_ratio
+ #reshape_h = int(target_img_shape[0] * resize_ratio + 0.5)
+ #reshape_w = int(target_img_shape[1] * resize_ratio + 0.5)
+ reshape_h, reshape_w, _ = images[0].shape
+ for i, other_label_i in enumerate(other_labels):
+ other_labels[i] = cv2.resize(other_label_i, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+
+ if transform_paras is not None:
+ transform_paras.update(label_scale_factor = 1.0/to_scale_ratio)
+
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+
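+# Illustrative sketch (not used by the data pipeline): resize a dummy sample to
+# the canonical focal length with a fixed resize ratio of 1.0. The parameter
+# values are assumptions for this example only.
+def _demo_resize_canonical():
+    t = ResizeCanonical(ratio_range=(1.0, 1.0), focal_length=1000.0, crop_size=(32, 32))
+    images = [np.zeros((16, 24, 3), dtype=np.uint8)]
+    labels = [np.full((16, 24), 2.0, dtype=np.float64)]
+    intrinsics = [[500.0, 500.0, 12.0, 8.0]]  # fx, fy, cx, cy
+    paras = {}
+    images, labels, intrinsics, _, _, _, paras = t(images, labels, intrinsics, transform_paras=paras)
+    # the image is resized to the canonical focal length; depth values are unchanged here
+    return images[0].shape, labels[0].mean(), paras['label_scale_factor']
+
+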
+class LabelScaleCononical(object):
+ """
+    To resolve the observation ambiguity in the mono branch, i.e. different focal lengths (object sizes) yielding the same depth, cameras are
+    mapped to a canonical space. To mimic this, we set the focal length to a canonical one and scale the depth values accordingly. NOTE: resizing the image by the same ratio can also resolve this ambiguity.
+ Args:
+ images: list of RGB images.
+ labels: list of depth/disparity labels.
+ other labels: other labels, such as instance segmentations, semantic segmentations...
+ """
+ def __init__(self, **kwargs):
+ self.canonical_focal = kwargs['focal_length']
+
+ def _get_scale_ratio(self, intrinsic):
+ target_focal_x = intrinsic[0]
+ label_scale_ratio = self.canonical_focal / target_focal_x
+ pose_scale_ratio = 1.0
+ return label_scale_ratio, pose_scale_ratio
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ assert len(images[0].shape) == 3 and len(labels[0].shape) == 2
+ #assert labels[0].dtype == np.float
+
+ label_scale_ratio = None
+ pose_scale_ratio = None
+
+ for i in range(len(intrinsics)):
+ img_i = images[i]
+ label_i = labels[i] if i < len(labels) else None
+ intrinsic_i = intrinsics[i].copy()
+ cam_model_i = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+
+ label_scale_ratio, pose_scale_ratio = self._get_scale_ratio(intrinsic_i)
+
+ # adjust the focal length, map the current camera to the canonical space
+ intrinsics[i] = [intrinsic_i[0]*label_scale_ratio, intrinsic_i[1]*label_scale_ratio, intrinsic_i[2], intrinsic_i[3]]
+
+ # scale the label to the canonical space
+ if label_i is not None:
+ labels[i] = label_i * label_scale_ratio
+
+ if cam_model_i is not None:
+ # As the focal length is adjusted (canonical focal length), the camera model should be re-built.
+ ori_h, ori_w, _ = img_i.shape
+ cam_models[i] = build_camera_model(ori_h, ori_w, intrinsics[i])
+
+
+ if transform_paras is not None:
+ transform_paras.update(label_scale_factor = label_scale_ratio)
+
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+
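+# Worked example (illustrative only): with a canonical focal length of 1000 and
+# an input camera with fx = 500, the scale ratio is 1000 / 500 = 2, so a pixel
+# labelled 3 m becomes 6 m and fx/fy are doubled, while the image itself is
+# left untouched. The numbers below are assumptions for this sketch.
+def _demo_label_scale_canonical():
+    t = LabelScaleCononical(focal_length=1000.0)
+    images = [np.zeros((8, 8, 3), dtype=np.uint8)]
+    labels = [np.full((8, 8), 3.0, dtype=np.float64)]
+    intrinsics = [[500.0, 500.0, 4.0, 4.0]]  # fx, fy, cx, cy
+    paras = {}
+    _, labels, intrinsics, _, _, _, paras = t(images, labels, intrinsics, transform_paras=paras)
+    return labels[0][0, 0], intrinsics[0][0], paras['label_scale_factor']  # 6.0, 1000.0, 2.0
+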
+
+class ResizeKeepRatio(object):
+ """
+    Resize and pad to a given size while preserving the aspect ratio.
+ This resizing assumes that the camera model remains unchanged.
+ Args:
+ resize_size: predefined output size.
+ """
+ def __init__(self, resize_size, padding=None, ignore_label=-1, **kwargs):
+ if isinstance(resize_size, int):
+ self.resize_h = resize_size
+ self.resize_w = resize_size
+ elif isinstance(resize_size, collections.Iterable) and len(resize_size) == 2 \
+ and isinstance(resize_size[0], int) and isinstance(resize_size[1], int) \
+ and resize_size[0] > 0 and resize_size[1] > 0:
+ self.resize_h = resize_size[0]
+ self.resize_w = resize_size[1]
+ else:
+ raise (RuntimeError("crop size error.\n"))
+ if padding is None:
+ self.padding = padding
+ elif isinstance(padding, list):
+ if all(isinstance(i, numbers.Number) for i in padding):
+ self.padding = padding
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if len(padding) != 3:
+ raise (RuntimeError("padding channel is not equal with 3\n"))
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if isinstance(ignore_label, int):
+ self.ignore_label = ignore_label
+ else:
+ raise (RuntimeError("ignore_label should be an integer number\n"))
+ self.crop_size = kwargs['crop_size']
+ self.canonical_focal = kwargs['focal_length']
+
+ def main_data_transform(self, image, label, intrinsic, cam_model, resize_ratio, padding, to_scale_ratio):
+ """
+ Resize data first and then do the padding.
+ 'label' will be scaled.
+ """
+ h, w, _ = image.shape
+ reshape_h = int(resize_ratio * h)
+ reshape_w = int(resize_ratio * w)
+
+ pad_h, pad_w, pad_h_half, pad_w_half = padding
+
+ # resize
+ image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ # padding
+ image = cv2.copyMakeBorder(
+ image,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.padding)
+
+ if label is not None:
+ # label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ label = resize_depth_preserve(label, (reshape_h, reshape_w))
+ label = cv2.copyMakeBorder(
+ label,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+ # scale the label
+ label = label / to_scale_ratio
+
+        # Resize, adjust the principal point
+ if intrinsic is not None:
+ intrinsic[2] = intrinsic[2] * resize_ratio
+ intrinsic[3] = intrinsic[3] * resize_ratio
+
+ if cam_model is not None:
+ #cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
+ cam_model = cv2.copyMakeBorder(
+ cam_model,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+
+        # Pad, adjust the principal point
+ if intrinsic is not None:
+ intrinsic[2] = intrinsic[2] + pad_w_half
+ intrinsic[3] = intrinsic[3] + pad_h_half
+ return image, label, intrinsic, cam_model
+
+ def get_label_scale_factor(self, image, intrinsic, resize_ratio):
+ ori_h, ori_w, _ = image.shape
+ crop_h, crop_w = self.crop_size
+ ori_focal = (intrinsic[0] + intrinsic[1]) / 2.0 #intrinsic[0] #
+
+ to_canonical_ratio = self.canonical_focal / ori_focal
+ to_scale_ratio = resize_ratio / to_canonical_ratio
+ return to_scale_ratio
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ target_h, target_w, _ = images[0].shape
+ resize_ratio_h = self.resize_h / target_h
+ resize_ratio_w = self.resize_w / target_w
+ resize_ratio = min(resize_ratio_h, resize_ratio_w)
+ reshape_h = int(resize_ratio * target_h)
+ reshape_w = int(resize_ratio * target_w)
+ pad_h = max(self.resize_h - reshape_h, 0)
+ pad_w = max(self.resize_w - reshape_w, 0)
+ pad_h_half = int(pad_h / 2)
+ pad_w_half = int(pad_w / 2)
+
+ pad_info = [pad_h, pad_w, pad_h_half, pad_w_half]
+ to_scale_ratio = self.get_label_scale_factor(images[0], intrinsics[0], resize_ratio)
+
+ for i in range(len(images)):
+ img = images[i]
+ label = labels[i] if i < len(labels) else None
+ intrinsic = intrinsics[i] if i < len(intrinsics) else None
+ cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+ img, label, intrinsic, cam_model = self.main_data_transform(
+ img, label, intrinsic, cam_model, resize_ratio, pad_info, to_scale_ratio)
+ images[i] = img
+ if label is not None:
+ labels[i] = label
+ if intrinsic is not None:
+ intrinsics[i] = intrinsic
+ if cam_model is not None:
+ cam_models[i] = cam_model
+
+ if normals is not None:
+ for i, normal in enumerate(normals):
+ normal = cv2.resize(normal, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ # pad
+ normals[i] = cv2.copyMakeBorder(
+ normal,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=0)
+
+ if other_labels is not None:
+
+ for i, other_lab in enumerate(other_labels):
+ # resize
+ other_lab = cv2.resize(other_lab, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ # pad
+ other_labels[i] = cv2.copyMakeBorder(
+ other_lab,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+
+
+ if transform_paras is not None:
+ transform_paras.update(pad=[pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half])
+ if 'label_scale_factor' in transform_paras:
+ transform_paras['label_scale_factor'] = transform_paras['label_scale_factor'] * 1.0 / to_scale_ratio
+ else:
+ transform_paras.update(label_scale_factor=1.0/to_scale_ratio)
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
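+# A worked example of ResizeKeepRatio's label scale factor (hypothetical numbers): a 480x640
+# image resized to (resize_h, resize_w) = (512, 960) gives resize_ratio = min(512/480, 960/640)
+# = 1.0667; with the canonical focal_length = 1000 and an original focal of 500,
+# to_canonical_ratio = 2.0, so to_scale_ratio = 1.0667 / 2.0 = 0.533 and depth labels are
+# divided by it, i.e. the metric depths grow, consistent with the canonical-camera convention.
+#
+#   t = ResizeKeepRatio((512, 960), padding=[123.675, 116.28, 103.53],
+#                       crop_size=(512, 960), focal_length=1000.0)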
+class KeepResizeCanoSize(object):
+ """
+    Resize the image so that its focal length matches the canonical one, keeping the aspect ratio,
+    then pad height and width up to multiples of 32; depth labels are kept unscaled here.
+ Args:
+ resize_size: predefined output size.
+ """
+ def __init__(self, resize_size, padding=None, ignore_label=-1, **kwargs):
+ if isinstance(resize_size, int):
+ self.resize_h = resize_size
+ self.resize_w = resize_size
+        elif isinstance(resize_size, collections.abc.Iterable) and len(resize_size) == 2 \
+ and isinstance(resize_size[0], int) and isinstance(resize_size[1], int) \
+ and resize_size[0] > 0 and resize_size[1] > 0:
+ self.resize_h = resize_size[0]
+ self.resize_w = resize_size[1]
+ else:
+ raise (RuntimeError("crop size error.\n"))
+ if padding is None:
+ self.padding = padding
+ elif isinstance(padding, list):
+ if all(isinstance(i, numbers.Number) for i in padding):
+ self.padding = padding
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if len(padding) != 3:
+                raise (RuntimeError("padding channel is not equal to 3\n"))
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if isinstance(ignore_label, int):
+ self.ignore_label = ignore_label
+ else:
+ raise (RuntimeError("ignore_label should be an integer number\n"))
+ self.crop_size = kwargs['crop_size']
+ self.canonical_focal = kwargs['focal_length']
+
+ def main_data_transform(self, image, label, intrinsic, cam_model, resize_ratio, padding, to_scale_ratio):
+ """
+ Resize data first and then do the padding.
+ 'label' will be scaled.
+ """
+ h, w, _ = image.shape
+ reshape_h = int(resize_ratio * h)
+ reshape_w = int(resize_ratio * w)
+
+ pad_h, pad_w, pad_h_half, pad_w_half = padding
+
+ # resize
+ image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ # padding
+ image = cv2.copyMakeBorder(
+ image,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.padding)
+
+ if label is not None:
+ # label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ label = resize_depth_preserve(label, (reshape_h, reshape_w))
+ label = cv2.copyMakeBorder(
+ label,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+ # scale the label
+ label = label / to_scale_ratio
+
+        # Resize, adjust the principal point
+ if intrinsic is not None:
+ intrinsic[2] = intrinsic[2] * resize_ratio
+ intrinsic[3] = intrinsic[3] * resize_ratio
+
+ if cam_model is not None:
+ #cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
+ cam_model = cv2.copyMakeBorder(
+ cam_model,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+
+        # Pad, adjust the principal point
+ if intrinsic is not None:
+ intrinsic[2] = intrinsic[2] + pad_w_half
+ intrinsic[3] = intrinsic[3] + pad_h_half
+ return image, label, intrinsic, cam_model
+
+ # def get_label_scale_factor(self, image, intrinsic, resize_ratio):
+ # ori_h, ori_w, _ = image.shape
+ # crop_h, crop_w = self.crop_size
+ # ori_focal = intrinsic[0] #(intrinsic[0] + intrinsic[1]) / 2.0
+
+ # to_canonical_ratio = self.canonical_focal / ori_focal
+ # to_scale_ratio = resize_ratio / to_canonical_ratio
+ # return to_scale_ratio
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ target_h, target_w, _ = images[0].shape
+ ori_focal = intrinsics[0][0]
+ to_canonical_ratio = self.canonical_focal / ori_focal
+
+ resize_ratio = to_canonical_ratio
+ reshape_h = int(resize_ratio * target_h)
+ reshape_w = int(resize_ratio * target_w)
+
+        pad_h = 32 - reshape_h % 32 if reshape_h % 32 != 0 else 0  # pad only when not already a multiple of 32
+        pad_w = 32 - reshape_w % 32 if reshape_w % 32 != 0 else 0
+ pad_h_half = int(pad_h / 2)
+ pad_w_half = int(pad_w / 2)
+
+ pad_info = [pad_h, pad_w, pad_h_half, pad_w_half]
+ to_scale_ratio = 1.0
+
+ for i in range(len(images)):
+ img = images[i]
+ label = labels[i] if i < len(labels) else None
+ intrinsic = intrinsics[i] if i < len(intrinsics) else None
+ cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+ img, label, intrinsic, cam_model = self.main_data_transform(
+ img, label, intrinsic, cam_model, resize_ratio, pad_info, to_scale_ratio)
+ images[i] = img
+ if label is not None:
+ labels[i] = label
+ if intrinsic is not None:
+ intrinsics[i] = intrinsic
+ if cam_model is not None:
+ cam_models[i] = cam_model
+
+ if normals is not None:
+
+ for i, normal in enumerate(normals):
+ # resize
+ normal = cv2.resize(normal, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ # pad
+ normals[i] = cv2.copyMakeBorder(
+ normal,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=0)
+
+ if other_labels is not None:
+
+ for i, other_lab in enumerate(other_labels):
+ # resize
+ other_lab = cv2.resize(other_lab, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ # pad
+ other_labels[i] = cv2.copyMakeBorder(
+ other_lab,
+ pad_h_half,
+ pad_h - pad_h_half,
+ pad_w_half,
+ pad_w - pad_w_half,
+ cv2.BORDER_CONSTANT,
+ value=self.ignore_label)
+
+
+ if transform_paras is not None:
+ transform_paras.update(pad=[pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half])
+ if 'label_scale_factor' in transform_paras:
+ transform_paras['label_scale_factor'] = transform_paras['label_scale_factor'] * 1.0 / to_scale_ratio
+ else:
+ transform_paras.update(label_scale_factor=1.0/to_scale_ratio)
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
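+# Sketch of what KeepResizeCanoSize does (hypothetical numbers): with focal_length = 1000 and
+# fx = 500, the image is up-scaled by 2x so that the resized focal length equals the canonical
+# one, and height/width are then padded up to multiples of 32 (the stride most encoders expect);
+# unlike ResizeKeepRatio, depth labels are left unscaled here (to_scale_ratio = 1.0).
+#
+#   t = KeepResizeCanoSize((512, 960), padding=[123.675, 116.28, 103.53],
+#                          crop_size=(512, 960), focal_length=1000.0)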
+
+class RandomCrop(object):
+ """Crops the given ndarray image (H*W*C or H*W).
+ Args:
+ size (sequence or int): Desired output size of the crop. If size is an
+ int instead of sequence like (h, w), a square crop (size, size) is made.
+ """
+ def __init__(self, crop_size, crop_type='center', padding=None, ignore_label=-1, **kwargs):
+ if isinstance(crop_size, int):
+ self.crop_h = crop_size
+ self.crop_w = crop_size
+        elif isinstance(crop_size, collections.abc.Iterable) and len(crop_size) == 2 \
+ and isinstance(crop_size[0], int) and isinstance(crop_size[1], int) \
+ and crop_size[0] > 0 and crop_size[1] > 0:
+ self.crop_h = crop_size[0]
+ self.crop_w = crop_size[1]
+ else:
+ raise (RuntimeError("crop size error.\n"))
+ if crop_type == 'center' or crop_type == 'rand' or crop_type=='rand_in_field':
+ self.crop_type = crop_type
+ else:
+ raise (RuntimeError("crop type error: rand | center | rand_in_field \n"))
+ if padding is None:
+ self.padding = padding
+ elif isinstance(padding, list):
+ if all(isinstance(i, numbers.Number) for i in padding):
+ self.padding = padding
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if len(padding) != 3:
+                raise (RuntimeError("padding channel is not equal to 3\n"))
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if isinstance(ignore_label, int):
+ self.ignore_label = ignore_label
+ else:
+ raise (RuntimeError("ignore_label should be an integer number\n"))
+
+
+ def cal_padding_paras(self, h, w):
+ # padding if current size is not satisfied
+ pad_h = max(self.crop_h - h, 0)
+ pad_w = max(self.crop_w - w, 0)
+ pad_h_half = int(pad_h / 2)
+ pad_w_half = int(pad_w / 2)
+ return pad_h, pad_w, pad_h_half, pad_w_half
+
+ def cal_cropping_paras(self, h, w, intrinsic):
+ u0 = intrinsic[2]
+ v0 = intrinsic[3]
+ if self.crop_type == 'rand':
+ h_min = 0
+ h_max = h - self.crop_h
+ w_min = 0
+ w_max = w - self.crop_w
+ elif self.crop_type == 'center':
+ h_min = (h - self.crop_h) / 2
+ h_max = (h - self.crop_h) / 2
+ w_min = (w - self.crop_w) / 2
+ w_max = (w - self.crop_w) / 2
+ else: # rand in field
+ h_min = min(max(0, v0 - 0.75*self.crop_h), h-self.crop_h)
+ h_max = min(max(v0 - 0.25*self.crop_h, 0), h-self.crop_h)
+ w_min = min(max(0, u0 - 0.75*self.crop_w), w-self.crop_w)
+ w_max = min(max(u0 - 0.25*self.crop_w, 0), w-self.crop_w)
+
+ h_off = random.randint(int(h_min), int(h_max))
+ w_off = random.randint(int(w_min), int(w_max))
+ return h_off, w_off
+
+ def main_data_transform(self, image, label, intrinsic, cam_model,
+ pad_h, pad_w, pad_h_half, pad_w_half, h_off, w_off):
+
+ # padding if current size is not satisfied
+ if pad_h > 0 or pad_w > 0:
+ if self.padding is None:
+                raise (RuntimeError("depthtransform.Crop() needs padding but the padding argument is None\n"))
+ image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.padding)
+ if label is not None:
+ label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label)
+ if cam_model is not None:
+ cam_model = cv2.copyMakeBorder(cam_model, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label)
+
+ # cropping
+ image = image[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
+ if label is not None:
+ label = label[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
+ if cam_model is not None:
+ cam_model = cam_model[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
+
+ if intrinsic is not None:
+ intrinsic[2] = intrinsic[2] + pad_w_half - w_off
+ intrinsic[3] = intrinsic[3] + pad_h_half - h_off
+ return image, label, intrinsic, cam_model
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+        if transform_paras is not None and 'random_crop_size' in transform_paras and transform_paras['random_crop_size'] is not None \
+            and (transform_paras['random_crop_size'][0] + transform_paras['random_crop_size'][1] > 500):
+ self.crop_h = int(transform_paras['random_crop_size'][0].item())
+ self.crop_w = int(transform_paras['random_crop_size'][1].item())
+ target_img = images[0]
+ target_h, target_w, _ = target_img.shape
+ target_intrinsic = intrinsics[0]
+ pad_h, pad_w, pad_h_half, pad_w_half = self.cal_padding_paras(target_h, target_w)
+ h_off, w_off = self.cal_cropping_paras(target_h+pad_h, target_w+pad_w, target_intrinsic)
+
+ for i in range(len(images)):
+ img = images[i]
+ label = labels[i] if i < len(labels) else None
+ intrinsic = intrinsics[i].copy() if i < len(intrinsics) else None
+ cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+ img, label, intrinsic, cam_model = self.main_data_transform(
+ img, label, intrinsic, cam_model,
+ pad_h, pad_w, pad_h_half, pad_w_half, h_off, w_off)
+ images[i] = img
+ if label is not None:
+ labels[i] = label
+ if intrinsic is not None:
+ intrinsics[i] = intrinsic
+ if cam_model is not None:
+ cam_models[i] = cam_model
+ pad=[pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]
+ if normals is not None:
+ for i, normal in enumerate(normals):
+ # padding if current size is not satisfied
+ normal = cv2.copyMakeBorder(normal, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=0)
+ normals[i] = normal[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
+ if other_labels is not None:
+ for i, other_lab in enumerate(other_labels):
+ # padding if current size is not satisfied
+ other_lab = cv2.copyMakeBorder(other_lab, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label)
+ other_labels[i] = other_lab[h_off:h_off+self.crop_h, w_off:w_off+self.crop_w]
+ if transform_paras is not None:
+ transform_paras.update(dict(pad=pad))
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
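+# Sketch of RandomCrop usage (hypothetical numbers): crop_type='rand_in_field' samples the crop
+# window so that the principal point (u0, v0) stays roughly within the central half of the crop
+# (modulo image-border clamping), and the intrinsics are shifted by (pad_w_half - w_off,
+# pad_h_half - h_off) so that (u0, v0) remains consistent with the cropped pixels.
+#
+#   crop = RandomCrop((512, 960), crop_type='rand_in_field',
+#                     padding=[123.675, 116.28, 103.53], ignore_label=-1)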
+
+class RandomResize(object):
+ """
+    Randomly resize the image. During this process the camera model is held fixed, so the depth label has to be scaled accordingly.
+ Args:
+ images: list of RGB images.
+ labels: list of depth/disparity labels.
+ other labels: other labels, such as instance segmentations, semantic segmentations...
+ """
+ def __init__(self, ratio_range=(0.85, 1.15), prob=0.5, is_lidar=True, **kwargs):
+ self.ratio_range = ratio_range
+ self.is_lidar = is_lidar
+ self.prob = prob
+
+ def random_resize(self, image, label, intrinsic, cam_model, to_random_ratio):
+ ori_h, ori_w, _ = image.shape
+
+ resize_ratio = to_random_ratio
+ label_scale_ratio = 1.0 / resize_ratio
+ reshape_h = int(ori_h * resize_ratio + 0.5)
+ reshape_w = int(ori_w * resize_ratio + 0.5)
+
+ image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ if intrinsic is not None:
+ intrinsic = [intrinsic[0], intrinsic[1], intrinsic[2]*resize_ratio, intrinsic[3]*resize_ratio]
+ if label is not None:
+ if self.is_lidar:
+ label = resize_depth_preserve(label, (reshape_h, reshape_w))
+ else:
+ label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ # scale the label
+ label = label * label_scale_ratio
+
+ if cam_model is not None:
+ # Should not directly resize the cam_model.
+ # Camera model should be resized in 'to canonical' stage, while it holds in 'random resizing' stage.
+ # cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
+
+ return image, label, intrinsic, cam_model, label_scale_ratio
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ assert len(images[0].shape) == 3 and len(labels[0].shape) == 2
+        assert labels[0].dtype == np.float64  # np.float (an alias of float64 here) was removed in NumPy 1.24
+ # target_focal = (intrinsics[0][0] + intrinsics[0][1]) / 2.0
+ # target_to_canonical_ratio = self.canonical_focal / target_focal
+ # target_img_shape = images[0].shape
+ prob = random.uniform(0, 1)
+ if prob < self.prob:
+ to_random_ratio = random.uniform(self.ratio_range[0], self.ratio_range[1])
+ else:
+ to_random_ratio = 1.0
+ label_scale_ratio = 0.0
+ for i in range(len(images)):
+ img = images[i]
+ label = labels[i] if i < len(labels) else None
+ intrinsic = intrinsics[i].copy() if i < len(intrinsics) else None
+ cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+ img, label, intrinsic, cam_model, label_scale_ratio = self.random_resize(
+ img, label, intrinsic, cam_model, to_random_ratio)
+
+ images[i] = img
+ if label is not None:
+ labels[i] = label
+ if intrinsic is not None:
+ intrinsics[i] = intrinsic.copy()
+ if cam_model is not None:
+ cam_models[i] = cam_model
+
+        if normals is not None:
+ reshape_h, reshape_w, _ = images[0].shape
+ for i, norm in enumerate(normals):
+ normals[i] = cv2.resize(norm, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+
+
+        if other_labels is not None:
+ # other labels are like semantic segmentations, instance segmentations, instance planes segmentations...
+ #resize_ratio = target_to_canonical_ratio * to_scale_ratio
+ #reshape_h = int(target_img_shape[0] * resize_ratio + 0.5)
+ #reshape_w = int(target_img_shape[1] * resize_ratio + 0.5)
+ reshape_h, reshape_w, _ = images[0].shape
+ for i, other_label_i in enumerate(other_labels):
+ other_labels[i] = cv2.resize(other_label_i, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+
+ if transform_paras is not None:
+ if 'label_scale_factor' in transform_paras:
+ transform_paras['label_scale_factor'] = transform_paras['label_scale_factor'] * label_scale_ratio
+ else:
+ transform_paras.update(label_scale_factor = label_scale_ratio)
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
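+# Sketch of the RandomResize scaling rule (hypothetical numbers): with resize_ratio = 1.1 the
+# image grows by 10% while fx/fy are kept, which mimics the camera moving closer along the
+# optical axis; depth labels are therefore multiplied by 1/1.1 and
+# transform_paras['label_scale_factor'] is multiplied by the same factor.
+#
+#   rr = RandomResize(ratio_range=(0.9, 1.1), prob=0.5, is_lidar=False)  # dense GT -> nearest resize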
+class RandomEdgeMask(object):
+ """
+    Randomly mask the image borders of the inputs and labels.
+ Args:
+ images: list of RGB images.
+ labels: list of depth/disparity labels.
+ other labels: other labels, such as instance segmentations, semantic segmentations...
+ """
+ def __init__(self, mask_maxsize=32, prob=0.5, rgb_invalid=[0,0,0], label_invalid=-1,**kwargs):
+ self.mask_maxsize = mask_maxsize
+ self.prob = prob
+ self.rgb_invalid = rgb_invalid
+ self.label_invalid = label_invalid
+
+ def mask_edge(self, image, mask_edgesize, mask_value):
+ H, W = image.shape[0], image.shape[1]
+ # up
+ image[0:mask_edgesize[0], :, ...] = mask_value
+ # down
+ image[H-mask_edgesize[1]:H, :, ...] = mask_value
+ # left
+ image[:, 0:mask_edgesize[2], ...] = mask_value
+ # right
+ image[:, W-mask_edgesize[3]:W, ...] = mask_value
+
+ return image
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ assert len(images[0].shape) == 3 and len(labels[0].shape) == 2
+        assert labels[0].dtype == np.float64  # np.float (an alias of float64 here) was removed in NumPy 1.24
+
+ prob = random.uniform(0, 1)
+ if prob > self.prob:
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+ mask_edgesize = random.sample(range(self.mask_maxsize), 4) #[up, down, left, right]
+ for i in range(len(images)):
+ img = images[i]
+ label = labels[i] if i < len(labels) else None
+ img = self.mask_edge(img, mask_edgesize, self.rgb_invalid)
+
+ images[i] = img
+ if label is not None:
+ label = self.mask_edge(label, mask_edgesize, self.label_invalid)
+ labels[i] = label
+
+        if normals is not None:
+ for i, normal in enumerate(normals):
+ normals[i] = self.mask_edge(normal, mask_edgesize, mask_value=0)
+
+        if other_labels is not None:
+ # other labels are like semantic segmentations, instance segmentations, instance planes segmentations...
+ for i, other_label_i in enumerate(other_labels):
+ other_labels[i] = self.mask_edge(other_label_i, mask_edgesize, self.label_invalid)
+
+ if transform_paras is not None:
+ pad = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0]
+ new_pad = [max(mask_edgesize[0], pad[0]), max(mask_edgesize[1], pad[1]), max(mask_edgesize[2], pad[2]), max(mask_edgesize[3], pad[3])]
+ transform_paras.update(dict(pad=new_pad))
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+
+class AdjustSize(object):
+    """Pads the given ndarray image (H*W*C or H*W) so that height and width become multiples of 32,
+    shifting the principal point by the left/top padding.
+    Args:
+        padding (list): per-channel padding value for the image.
+        ignore_label (int): padding value for labels and camera-model maps.
+    """
+ def __init__(self, padding=None, ignore_label=-1, **kwargs):
+
+ if padding is None:
+ self.padding = padding
+ elif isinstance(padding, list):
+ if all(isinstance(i, numbers.Number) for i in padding):
+ self.padding = padding
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if len(padding) != 3:
+                raise (RuntimeError("padding channel is not equal to 3\n"))
+ else:
+ raise (RuntimeError("padding in Crop() should be a number list\n"))
+ if isinstance(ignore_label, int):
+ self.ignore_label = ignore_label
+ else:
+ raise (RuntimeError("ignore_label should be an integer number\n"))
+
+ def get_pad_paras(self, h, w):
+ pad_h = 32 - h % 32 if h %32 != 0 else 0
+ pad_w = 32 - w % 32 if w %32 != 0 else 0
+ pad_h_half = int(pad_h // 2)
+ pad_w_half = int(pad_w // 2)
+ return pad_h, pad_w, pad_h_half, pad_w_half
+
+ def main_data_transform(self, image, label, intrinsic, cam_model):
+ h, w, _ = image.shape
+ pad_h, pad_w, pad_h_half, pad_w_half = self.get_pad_paras(h=h, w=w)
+ if pad_h > 0 or pad_w > 0:
+ if self.padding is None:
+                raise (RuntimeError("depthtransform.Crop() needs padding but the padding argument is None\n"))
+ image = cv2.copyMakeBorder(image, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.padding)
+ if label is not None:
+ label = cv2.copyMakeBorder(label, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label)
+ if cam_model is not None:
+ cam_model = cv2.copyMakeBorder(cam_model, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label)
+
+ if intrinsic is not None:
+ intrinsic[2] = intrinsic[2] + pad_w_half
+ intrinsic[3] = intrinsic[3] + pad_h_half
+ pad=[pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half]
+ return image, label, intrinsic, cam_model, pad
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ target_img = images[0]
+ target_h, target_w, _ = target_img.shape
+ for i in range(len(images)):
+ img = images[i]
+ label = labels[i] if i < len(labels) else None
+ intrinsic = intrinsics[i] if i < len(intrinsics) else None
+ cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+ img, label, intrinsic, cam_model, pad = self.main_data_transform(
+ img, label, intrinsic, cam_model)
+ images[i] = img
+ if label is not None:
+ labels[i] = label
+ if intrinsic is not None:
+ intrinsics[i] = intrinsic
+ if cam_model is not None:
+ cam_models[i] = cam_model
+
+ if transform_paras is not None:
+ transform_paras.update(dict(pad=pad))
+ if normals is not None:
+ pad_h, pad_w, pad_h_half, pad_w_half = self.get_pad_paras(h=target_h, w=target_w)
+ for i, normal in enumerate(normals):
+ normals[i] = cv2.copyMakeBorder(normal, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=0)
+
+ if other_labels is not None:
+ pad_h, pad_w, pad_h_half, pad_w_half = self.get_pad_paras(h=target_h, w=target_w)
+ for i, other_lab in enumerate(other_labels):
+ other_labels[i] = cv2.copyMakeBorder(other_lab, pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half, cv2.BORDER_CONSTANT, value=self.ignore_label)
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+
+class RandomHorizontalFlip(object):
+ def __init__(self, prob=0.5, **kwargs):
+ self.p = prob
+
+ def main_data_transform(self, image, label, intrinsic, cam_model, rotate):
+ if rotate:
+ image = cv2.flip(image, 1)
+ if label is not None:
+ label = cv2.flip(label, 1)
+ if intrinsic is not None:
+ h, w, _ = image.shape
+                intrinsic[2] = w - intrinsic[2]
+                # a horizontal flip only mirrors the x axis, so v0 (intrinsic[3]) is left unchanged
+ if cam_model is not None:
+ cam_model = cv2.flip(cam_model, 1)
+ cam_model[:, :, 0] = cam_model[:, :, 0] * -1
+ cam_model[:, :, 2] = cam_model[:, :, 2] * -1
+ return image, label, intrinsic, cam_model
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+        rotate = random.random() < self.p  # flip with probability self.p, matching the other augmentations
+
+ for i in range(len(images)):
+ img = images[i]
+ label = labels[i] if i < len(labels) else None
+ intrinsic = intrinsics[i] if i < len(intrinsics) else None
+ cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+ img, label, intrinsic, cam_model = self.main_data_transform(
+ img, label, intrinsic, cam_model, rotate)
+ images[i] = img
+ if label is not None:
+ labels[i] = label
+ if intrinsic is not None:
+ intrinsics[i] = intrinsic
+ if cam_model is not None:
+ cam_models[i] = cam_model
+ if normals is not None:
+ for i, normal in enumerate(normals):
+ if rotate:
+ normal = cv2.flip(normal, 1)
+ normal[:, :, 0] = -normal[:, :, 0] # NOTE: check the direction of normal coordinates axis, this is used in https://github.com/baegwangbin/surface_normal_uncertainty
+ normals[i] = normal
+
+ if other_labels is not None:
+ for i, other_lab in enumerate(other_labels):
+ if rotate:
+ other_lab = cv2.flip(other_lab, 1)
+ other_labels[i] = other_lab
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
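+# Sketch of the flip geometry (hypothetical numbers): for an image of width W = 640, a
+# horizontal flip maps the principal point u0 = 300 to approximately W - u0 = 340 (a purely
+# horizontal flip leaves v0 unaffected), and the x components of the camera-model channels and
+# of the surface normals change sign for the same reason.
+#
+#   flip = RandomHorizontalFlip(prob=0.5)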
+class RandomBlur(object):
+ def __init__(self,
+ aver_kernal=(2, 10),
+ motion_kernal=(5, 15),
+ angle=[-80, 80],
+ prob=0.3,
+ **kwargs):
+
+        average_blur = iaa.AverageBlur(k=aver_kernal)  # box (average) blur, not a Gaussian
+        motion_blur = iaa.MotionBlur(k=motion_kernal, angle=angle)
+        zoom_blur = iaa.imgcorruptlike.ZoomBlur(severity=1)
+        self.prob = prob
+        self.blurs = [average_blur, motion_blur, zoom_blur]
+
+ def blur(self, imgs, id):
+ blur_mtd = self.blurs[id]
+ return blur_mtd(images=imgs)
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ prob = random.random()
+ if prob < self.prob:
+ id = random.randint(0, len(self.blurs)-1)
+ images = self.blur(images, id)
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+class RGBCompresion(object):
+ def __init__(self, prob=0.1, compression=(0, 50), **kwargs):
+ self.rgb_compress = iaa.Sequential(
+ [
+ iaa.JpegCompression(compression=compression),
+ ],
+ random_order=True,
+ )
+ self.prob = prob
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ if random.random() < self.prob:
+ images = self.rgb_compress(images=images)
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+
+class RGB2BGR(object):
+ # Converts image from RGB order to BGR order, for model initialized from Caffe
+ def __init__(self, **kwargs):
+ return
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ for i, img in enumerate(images):
+ images[i] = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+
+class BGR2RGB(object):
+ # Converts image from BGR order to RGB order, for model initialized from Pytorch
+ def __init__(self, **kwargs):
+ return
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ for i, img in enumerate(images):
+ images[i] = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+
+class PhotoMetricDistortion(object):
+    """Apply photometric distortions to the image sequentially; each transformation
+    is applied with probability `distortion_prob` (or the whole image is converted to
+    gray with probability `to_gray_prob` instead). The position of random contrast is
+    second or second to last:
+ 1. random brightness
+ 2. random contrast (mode 0)
+ 3. convert color from BGR to HSV
+ 4. random saturation
+ 5. random hue
+ 6. convert color from HSV to BGR
+ 7. random contrast (mode 1)
+ Args:
+ brightness_delta (int): delta of brightness.
+ contrast_range (tuple): range of contrast.
+ saturation_range (tuple): range of saturation.
+ hue_delta (int): delta of hue.
+ """
+
+ def __init__(self,
+ brightness_delta=32,
+ contrast_range=(0.5, 1.5),
+ saturation_range=(0.5, 1.5),
+ hue_delta=18,
+ to_gray_prob=0.3,
+ distortion_prob=0.3,
+ **kwargs):
+ self.brightness_delta = brightness_delta
+ self.contrast_lower, self.contrast_upper = contrast_range
+ self.saturation_lower, self.saturation_upper = saturation_range
+ self.hue_delta = hue_delta
+ self.gray_aug = iaa.Grayscale(alpha=(0.8, 1.0))
+ self.to_gray_prob = to_gray_prob
+ self.distortion_prob = distortion_prob
+
+ def convert(self, img, alpha=1.0, beta=0.0):
+        """Multiply the image by alpha and add beta, then clip to [0, 255]."""
+ img = img.astype(np.float32) * alpha + beta
+ img = np.clip(img, 0, 255)
+ return img.astype(np.uint8)
+
+ def brightness(self, img, beta, do):
+ """Brightness distortion."""
+ if do:
+ # beta = random.uniform(-self.brightness_delta,
+ # self.brightness_delta)
+ img = self.convert(
+ img,
+ beta=beta)
+ return img
+
+ def contrast(self, img, alpha, do):
+ """Contrast distortion."""
+ if do:
+ #alpha = random.uniform(self.contrast_lower, self.contrast_upper)
+ img = self.convert(
+ img,
+ alpha=alpha)
+ return img
+
+ def saturation(self, img, alpha, do):
+ """Saturation distortion."""
+ if do:
+ # alpha = random.uniform(self.saturation_lower,
+ # self.saturation_upper)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+ img[:, :, 1] = self.convert(
+ img[:, :, 1],
+ alpha=alpha)
+ img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
+ return img
+
+ def hue(self, img, rand_hue, do):
+ """Hue distortion."""
+ if do:
+ # rand_hue = random.randint(-self.hue_delta, self.hue_delta)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+ img[:, :, 0] = (img[:, :, 0].astype(int) + rand_hue) % 180
+ img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
+ return img
+
+ def rgb2gray(self, img):
+ img = self.gray_aug(image=img)
+ return img
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+        """Apply the photometric distortions to each image in the list.
+        Args:
+            images (list[np.ndarray]): input images.
+        Returns:
+            The same tuple of inputs, with `images` distorted; all other inputs are passed through unchanged.
+        """
+ brightness_beta = random.uniform(-self.brightness_delta, self.brightness_delta)
+ brightness_do = random.random() < self.distortion_prob
+
+ contrast_alpha = random.uniform(self.contrast_lower, self.contrast_upper)
+ contrast_do = random.random() < self.distortion_prob
+
+ saturate_alpha = random.uniform(self.saturation_lower, self.saturation_upper)
+ saturate_do = random.random() < self.distortion_prob
+
+ rand_hue = random.randint(-self.hue_delta, self.hue_delta)
+ rand_hue_do = random.random() < self.distortion_prob
+
+        # mode == 1 --> apply random contrast first (before saturation/hue)
+        # mode == 0 --> apply random contrast last
+        mode = 1 if random.random() > 0.5 else 0
+ for i, img in enumerate(images):
+ if random.random() < self.to_gray_prob:
+ img = self.rgb2gray(img)
+ else:
+ # random brightness
+ img = self.brightness(img, brightness_beta, brightness_do)
+
+ if mode == 1:
+ img = self.contrast(img, contrast_alpha, contrast_do)
+
+ # random saturation
+ img = self.saturation(img, saturate_alpha, saturate_do)
+
+ # random hue
+ img = self.hue(img, rand_hue, rand_hue_do)
+
+ # random contrast
+ if mode == 0:
+ img = self.contrast(img, contrast_alpha, contrast_do)
+ images[i] = img
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
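+# The convert() helper above implements img' = clip(alpha * img + beta, 0, 255) in float32 and
+# casts back to uint8; e.g. (hypothetical numbers) alpha = 1.2 and beta = 10 map a pixel value
+# of 200 to min(1.2 * 200 + 10, 255) = 250. Brightness, contrast and saturation all reuse it
+# with different (alpha, beta) choices.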
+class Weather(object):
+ """Apply the following weather augmentations to data.
+ Args:
+ prob (float): probability to enforce the weather augmentation.
+ """
+
+ def __init__(self,
+ prob=0.3,
+ **kwargs):
+ snow = iaa.FastSnowyLandscape(
+ lightness_threshold=[50, 100],
+ lightness_multiplier=(1.2, 2)
+ )
+ cloud = iaa.Clouds()
+ fog = iaa.Fog()
+ snow_flakes = iaa.Snowflakes(flake_size=(0.2, 0.4), speed=(0.001, 0.03)) #iaa.imgcorruptlike.Snow(severity=2)#
+ rain = iaa.Rain(speed=(0.1, 0.3), drop_size=(0.1, 0.3))
+ # rain_drops = RainDrop_Augmentor()
+ self.aug_list = [
+ snow, cloud, fog, snow_flakes, rain,
+ #wa.add_sun_flare, wa.darken, wa.random_brightness,
+ ]
+ self.prob = prob
+
+ def aug_with_weather(self, imgs, id):
+ weather = self.aug_list[id]
+ if id <5:
+ return weather(images=imgs)
+ else:
+ return weather(imgs)
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+        """Apply a randomly selected weather augmentation to the images.
+        Args:
+            images (list[np.ndarray]): input images.
+        Returns:
+            The same tuple of inputs, with `images` augmented; labels and intrinsics are unchanged.
+        """
+ if random.random() < self.prob:
+ select_id = np.random.randint(0, high=len(self.aug_list))
+ images = self.aug_with_weather(images, select_id)
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
+
+def resize_depth_preserve(depth, shape):
+ """
+ Resizes depth map preserving all valid depth pixels
+ Multiple downsampled points can be assigned to the same pixel.
+
+ Parameters
+ ----------
+ depth : np.array [h,w]
+ Depth map
+ shape : tuple (H,W)
+ Output shape
+
+ Returns
+ -------
+    depth : np.array [H,W]
+ Resized depth map
+ """
+ # Store dimensions and reshapes to single column
+ depth = np.squeeze(depth)
+ h, w = depth.shape
+ x = depth.reshape(-1)
+ # Create coordinate grid
+ uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2)
+ # Filters valid points
+ idx = x > 0
+ crd, val = uv[idx], x[idx]
+ # Downsamples coordinates
+ crd[:, 0] = (crd[:, 0] * (shape[0] / h) + 0.5).astype(np.int32)
+ crd[:, 1] = (crd[:, 1] * (shape[1] / w) + 0.5).astype(np.int32)
+ # Filters points inside image
+ idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1])
+ crd, val = crd[idx], val[idx]
+ # Creates downsampled depth image and assigns points
+ depth = np.zeros(shape)
+ depth[crd[:, 0], crd[:, 1]] = val
+ # Return resized depth map
+ return depth
+
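+# Sketch (hypothetical shapes): downsampling a sparse 4x4 LiDAR depth map to 2x2 keeps every
+# valid (> 0) pixel by rescaling its (row, col) index and rounding; points that collide on the
+# same target pixel simply overwrite each other, which is acceptable for sparse supervision and
+# avoids the interpolation artifacts a plain cv2.resize would introduce.
+#
+#   small = resize_depth_preserve(depth_4x4, (2, 2))   # small.shape == (2, 2)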
+
+def gray_to_colormap(img, cmap='rainbow', max_value=None):
+ """
+ Transfer gray map to matplotlib colormap
+ """
+ assert img.ndim == 2
+
+ img[img<0] = 0
+ mask_invalid = img < 1e-10
+    if max_value is None:
+ img = img / (img.max() + 1e-8)
+ else:
+ img = img / (max_value + 1e-8)
+ norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1)
+ cmap_m = matplotlib.cm.get_cmap(cmap)
+    mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m)
+    colormap = (mappable.to_rgba(img)[:, :, :3] * 255).astype(np.uint8)
+ colormap[mask_invalid] = 0
+ return colormap
+
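+# Sketch: gray_to_colormap normalises the map by its own maximum (or by max_value when given),
+# applies the chosen matplotlib colormap and returns an HxWx3 uint8 image; entries <= 0 are
+# treated as invalid and rendered black (`depth_map` below is a placeholder HxW array).
+#
+#   vis = gray_to_colormap(depth_map, cmap='rainbow')   # vis.dtype == np.uint8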
+
+class LiDarResizeCanonical(object):
+ """
+    Resize the input to the canonical camera space first, then resize it again by a randomly sampled ratio.
+    In the first stage, the metric distance is assumed to hold while the camera model varies.
+    In the second stage, we simulate observations at different distances, as if the camera moved along the optical axis.
+ """
+ def __init__(self, **kwargs):
+ self.ratio_range = kwargs['ratio_range']
+ self.canonical_focal = kwargs['focal_length']
+ self.crop_size = kwargs['crop_size']
+
+ def random_on_canonical_transform(self, image, label, intrinsic, cam_model, to_random_ratio):
+ ori_h, ori_w, _ = image.shape
+ ori_focal = (intrinsic[0] + intrinsic[1]) / 2.0
+
+ to_canonical_ratio = self.canonical_focal / ori_focal
+ to_scale_ratio = to_random_ratio
+ resize_ratio = to_canonical_ratio * to_random_ratio
+ reshape_h = int(ori_h * resize_ratio + 0.5)
+ reshape_w = int(ori_w * resize_ratio + 0.5)
+
+ image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ if intrinsic is not None:
+ intrinsic = [self.canonical_focal, self.canonical_focal, intrinsic[2]*resize_ratio, intrinsic[3]*resize_ratio]
+ if label is not None:
+ # number of other labels may be less than that of image
+ #label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ label = resize_depth_preserve(label, (reshape_h, reshape_w))
+ # scale the label and camera intrinsics
+ label = label / to_scale_ratio
+
+ if cam_model is not None:
+ # Should not directly resize the cam_model.
+ # Camera model should be resized in 'to canonical' stage, while it holds in 'random resizing' stage.
+ # cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
+ return image, label, intrinsic, cam_model, to_scale_ratio
+
+ def random_on_crop_transform(self, image, label, intrinsic, cam_model, to_random_ratio):
+ ori_h, ori_w, _ = image.shape
+ crop_h, crop_w = self.crop_size
+ ori_focal = (intrinsic[0] + intrinsic[1]) / 2.0
+
+ to_canonical_ratio = self.canonical_focal / ori_focal
+
+ # random resize based on the last crop size
+ proposal_reshape_h = int(crop_h * to_random_ratio + 0.5)
+ proposal_reshape_w = int(crop_w * to_random_ratio + 0.5)
+ resize_ratio_h = proposal_reshape_h / ori_h
+ resize_ratio_w = proposal_reshape_w / ori_w
+ resize_ratio = min(resize_ratio_h, resize_ratio_w) # resize based on the long edge
+ reshape_h = int(ori_h * resize_ratio + 0.5)
+ reshape_w = int(ori_w * resize_ratio + 0.5)
+
+ to_scale_ratio = resize_ratio / to_canonical_ratio
+
+ image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ if intrinsic is not None:
+ intrinsic = [self.canonical_focal, self.canonical_focal, intrinsic[2]*resize_ratio, intrinsic[3]*resize_ratio]
+ if label is not None:
+ # number of other labels may be less than that of image
+ # label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+ label = resize_depth_preserve(label, (reshape_h, reshape_w))
+ # scale the label and camera intrinsics
+ label = label / to_scale_ratio
+
+ if cam_model is not None:
+ # Should not directly resize the cam_model.
+ # Camera model should be resized in 'to canonical' stage, while it holds in 'random resizing' stage.
+ # cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR)
+ cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
+ return image, label, intrinsic, cam_model, to_scale_ratio
+
+ def __call__(self, images, labels, intrinsics, cam_models=None, normals=None, other_labels=None, transform_paras=None):
+ assert len(images[0].shape) == 3 and len(labels[0].shape) == 2
+        assert labels[0].dtype == np.float64  # np.float (an alias of float64 here) was removed in NumPy 1.24
+ target_focal = (intrinsics[0][0] + intrinsics[0][1]) / 2.0
+ target_to_canonical_ratio = self.canonical_focal / target_focal
+ target_img_shape = images[0].shape
+ to_random_ratio = random.uniform(self.ratio_range[0], self.ratio_range[1])
+ to_scale_ratio = 0
+ for i in range(len(images)):
+ img = images[i]
+ label = labels[i] if i < len(labels) else None
+ intrinsic = intrinsics[i] if i < len(intrinsics) else None
+ cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None
+ img, label, intrinsic, cam_model, to_scale_ratio = self.random_on_canonical_transform(
+ img, label, intrinsic, cam_model, to_random_ratio)
+
+ images[i] = img
+ if label is not None:
+ labels[i] = label
+ if intrinsic is not None:
+ intrinsics[i] = intrinsic
+ if cam_model is not None:
+ cam_models[i] = cam_model
+        if normals is not None:
+ reshape_h, reshape_w, _ = images[0].shape
+ for i, normal in enumerate(normals):
+ normals[i] = cv2.resize(normal, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+
+        if other_labels is not None:
+ # other labels are like semantic segmentations, instance segmentations, instance planes segmentations...
+ # resize_ratio = target_to_canonical_ratio * to_random_ratio
+ # reshape_h = int(target_img_shape[0] * resize_ratio + 0.5)
+ # reshape_w = int(target_img_shape[1] * resize_ratio + 0.5)
+ reshape_h, reshape_w, _ = images[0].shape
+ for i, other_label_i in enumerate(other_labels):
+ other_labels[i] = cv2.resize(other_label_i, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST)
+
+ if transform_paras is not None:
+ transform_paras.update(label_scale_factor = 1.0/to_scale_ratio)
+
+ return images, labels, intrinsics, cam_models, normals, other_labels, transform_paras
+
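+# A worked example of the two ratios above (hypothetical numbers): with focal_length = 1000,
+# fx = fy = 500 and to_random_ratio = 1.2, the image is resized by to_canonical_ratio *
+# to_random_ratio = 2.0 * 1.2 = 2.4, the intrinsics become the canonical focal length, and the
+# depth labels are divided by to_scale_ratio = 1.2: only the random part changes the metric
+# depth, the canonical mapping does not.
+#
+#   t = LiDarResizeCanonical(ratio_range=(0.9, 1.4), focal_length=1000.0, crop_size=(512, 960))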
+
+
+def build_camera_model(H : int, W : int, intrinsics : list) -> np.array:
+ """
+    Encode the camera intrinsic parameters (focal length and principal point) into a 4-channel map.
+ """
+ fx, fy, u0, v0 = intrinsics
+ f = (fx + fy) / 2.0
+    # principal point location
+ x_row = np.arange(0, W).astype(np.float32)
+ x_row_center_norm = (x_row - u0) / W
+ x_center = np.tile(x_row_center_norm, (H, 1)) # [H, W]
+
+ y_col = np.arange(0, H).astype(np.float32)
+ y_col_center_norm = (y_col - v0) / H
+ y_center = np.tile(y_col_center_norm, (W, 1)).T
+
+ # FoV
+ fov_x = np.arctan(x_center / (f / W))
+ fov_y = np.arctan(y_center/ (f / H))
+
+ cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2)
+ return cam_model
+
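+# A small shape check for build_camera_model (hypothetical numbers): channel 0 holds (x - u0)/W,
+# channel 1 holds (y - v0)/H, and channels 2/3 are the arctan of those values divided by f/W and
+# f/H, i.e. a dense, resolution-normalised encoding of where each pixel lies relative to the
+# principal point.
+#
+#   cm = build_camera_model(2, 3, [2.0, 2.0, 1.0, 1.0])   # cm.shape == (2, 3, 4)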
+
+if __name__ == '__main__':
+ img = cv2.imread('/mnt/mldb/raw/62b3ed3455e805efcb28c74b/NuScenes/data_test/samples/CAM_FRONT/n008-2018-08-01-15-34-25-0400__CAM_FRONT__1533152214512404.jpg', -1)
+ H, W, _ = img.shape
+ label = img[:, :, 0]
+ intrinsic = [1000, 1000, W//2, H//2]
+ for i in range(20):
+ weather_aug = Weather(prob=1.0)
+        img_aug, labels, intrinsics, cam_models, normals, other_labels, transform_paras = weather_aug([img, ], [label,], [intrinsic,])
+ cv2.imwrite(f'test_aug_{i}.jpg', img_aug[0])
+
+ print('Done')
diff --git a/training/mono/utils/unproj_pcd.py b/training/mono/utils/unproj_pcd.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2b67167c138a12884b36bdb18cc86cbdca2de6
--- /dev/null
+++ b/training/mono/utils/unproj_pcd.py
@@ -0,0 +1,82 @@
+import numpy as np
+import torch
+from plyfile import PlyData, PlyElement
+import cv2
+
+def get_pcd_base(H, W, u0, v0, focal_length):
+ x_row = np.arange(0, W)
+ x = np.tile(x_row, (H, 1))
+ x = x.astype(np.float32)
+ u_m_u0 = x - u0
+
+ y_col = np.arange(0, H) # y_col = np.arange(0, height)
+ y = np.tile(y_col, (W, 1)).T
+ y = y.astype(np.float32)
+ v_m_v0 = y - v0
+
+ x = u_m_u0 / focal_length
+ y = v_m_v0 / focal_length
+ z = np.ones_like(x)
+ pw = np.stack([x, y, z], 2) # [h, w, c]
+ return pw
+
+def reconstruct_pcd(depth, focal_length, u0, v0, pcd_base=None, mask=None):
+    if isinstance(depth, torch.Tensor):
+ depth = depth.cpu().numpy().squeeze()
+ depth = cv2.medianBlur(depth, 5)
+ if pcd_base is None:
+ H, W = depth.shape
+ pcd_base = get_pcd_base(H, W, u0, v0, focal_length)
+ pcd = depth[:, :, None] * pcd_base
+    if mask is not None:
+ pcd[mask] = 0
+ return pcd
+
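+# Sketch of the pinhole unprojection used above (hypothetical values): a pixel (u, v) with depth
+# d maps to X = d * (u - u0) / f, Y = d * (v - v0) / f, Z = d, so get_pcd_base caches the
+# per-pixel ((u - u0)/f, (v - v0)/f, 1) rays and reconstruct_pcd multiplies them by the depth map.
+#
+#   pcd = reconstruct_pcd(depth, focal_length=1000.0, u0=W / 2, v0=H / 2)   # pcd.shape == (H, W, 3)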
+
+def save_point_cloud(pcd, rgb, filename, binary=True):
+ """Save an RGB point cloud as a PLY file.
+ :paras
+ @pcd: Nx3 matrix, the XYZ coordinates
+    @rgb: Nx3 matrix, the RGB colors for each 3D point (may be None for a uniform gray)
+ """
+    assert rgb is None or pcd.shape[0] == rgb.shape[0]
+
+ if rgb is None:
+ gray_concat = np.tile(np.array([128], dtype=np.uint8), (pcd.shape[0], 3))
+ points_3d = np.hstack((pcd, gray_concat))
+ else:
+ points_3d = np.hstack((pcd, rgb))
+ python_types = (float, float, float, int, int, int)
+ npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'),
+ ('blue', 'u1')]
+ if binary is True:
+ # Format into NumPy structured array
+ vertices = []
+ for row_idx in range(points_3d.shape[0]):
+ cur_point = points_3d[row_idx]
+ vertices.append(tuple(dtype(point) for dtype, point in zip(python_types, cur_point)))
+ vertices_array = np.array(vertices, dtype=npy_types)
+ el = PlyElement.describe(vertices_array, 'vertex')
+
+ # Write
+ PlyData([el]).write(filename)
+ else:
+ x = np.squeeze(points_3d[:, 0])
+ y = np.squeeze(points_3d[:, 1])
+ z = np.squeeze(points_3d[:, 2])
+ r = np.squeeze(points_3d[:, 3])
+ g = np.squeeze(points_3d[:, 4])
+ b = np.squeeze(points_3d[:, 5])
+
+ ply_head = 'ply\n' \
+ 'format ascii 1.0\n' \
+ 'element vertex %d\n' \
+ 'property float x\n' \
+ 'property float y\n' \
+ 'property float z\n' \
+ 'property uchar red\n' \
+ 'property uchar green\n' \
+ 'property uchar blue\n' \
+ 'end_header' % r.shape[0]
+ # ---- Save ply data to disk
+        np.savetxt(filename, np.column_stack((x, y, z, r, g, b)), fmt="%f %f %f %d %d %d", header=ply_head, comments='')
diff --git a/training/mono/utils/visualization.py b/training/mono/utils/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ed4c734981b416396c111fa0615bc0f4a6e8a7d
--- /dev/null
+++ b/training/mono/utils/visualization.py
@@ -0,0 +1,209 @@
+import matplotlib.pyplot as plt
+import os, cv2
+import numpy as np
+from mono.utils.transform import gray_to_colormap
+import shutil
+import glob
+from mono.utils.running import main_process
+import torch
+from html4vision import Col, imagetable
+
+def save_raw_imgs(
+ pred: torch.tensor,
+ rgb: torch.tensor,
+ filename: str,
+ save_dir: str,
+ scale: float=1000.0,
+ target: torch.tensor=None,
+ ):
+ """
+ Save raw GT, predictions, RGB in the same file.
+ """
+ cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_rgb.jpg'), rgb)
+    cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_pred.png'), (pred*scale).astype(np.uint16))
+ if target is not None:
+ cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_gt.png'), (target*scale).astype(np.uint16))
+
+def save_normal_val_imgs(
+ iter: int,
+ pred: torch.tensor,
+    targ: torch.tensor,
+    rgb: torch.tensor,
+ filename: str,
+ save_dir: str,
+ tb_logger=None,
+ mask=None,
+ ):
+ """
+ Save GT, predictions, RGB in the same file.
+ """
+ mean = np.array([123.675, 116.28, 103.53])[np.newaxis, np.newaxis, :]
+ std= np.array([58.395, 57.12, 57.375])[np.newaxis, np.newaxis, :]
+ pred = pred.squeeze()
+
+ # if pred.size(0) == 3:
+ # pred = pred.permute(1,2,0)
+ # pred_color = vis_surface_normal(pred, mask)
+
+ # #save one image only
+ # plt.imsave(os.path.join(save_dir, filename[:-4]+'.jpg'), pred_color)
+
+ targ = targ.squeeze()
+ rgb = rgb.squeeze()
+
+ if pred.size(0) == 3:
+ pred = pred.permute(1,2,0)
+ if targ.size(0) == 3:
+ targ = targ.permute(1,2,0)
+ if rgb.size(0) == 3:
+ rgb = rgb.permute(1,2,0)
+
+ pred_color = vis_surface_normal(pred, mask)
+ targ_color = vis_surface_normal(targ, mask)
+ rgb_color = ((rgb.cpu().numpy() * std) + mean).astype(np.uint8)
+
+ try:
+ cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)
+ except:
+ pred_color = cv2.resize(pred_color, (rgb.shape[1], rgb.shape[0]))
+ targ_color = cv2.resize(targ_color, (rgb.shape[1], rgb.shape[0]))
+ cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0)
+
+ plt.imsave(os.path.join(save_dir, filename[:-4]+'_merge.jpg'), cat_img)
+ # cv2.imwrite(os.path.join(save_dir, filename[:-4]+'.jpg'), pred_color)
+ # save to tensorboard
+ if tb_logger is not None:
+ tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter)
+
+
+
+
+def save_val_imgs(
+ iter: int,
+ pred: torch.tensor,
+ target: torch.tensor,
+ rgb: torch.tensor,
+ filename: str,
+ save_dir: str,
+ tb_logger=None
+ ):
+ """
+ Save GT, predictions, RGB in the same file.
+ """
+ rgb, pred_scale, target_scale, pred_color, target_color, max_scale = get_data_for_log(pred, target, rgb)
+ rgb = rgb.transpose((1, 2, 0))
+ # plt.imsave(os.path.join(save_dir, filename[:-4]+'_rgb.jpg'), rgb)
+ # plt.imsave(os.path.join(save_dir, filename[:-4]+'_pred.png'), pred_scale, cmap='rainbow')
+ # plt.imsave(os.path.join(save_dir, filename[:-4]+'_gt.png'), target_scale, cmap='rainbow')
+ cat_img = np.concatenate([rgb, pred_color, target_color], axis=0)
+ plt.imsave(os.path.join(save_dir, filename[:-4]+'_merge.jpg'), cat_img)
+
+ # save to tensorboard
+ if tb_logger is not None:
+ # tb_logger.add_image(f'{filename[:-4]}_rgb.jpg', rgb, iter)
+ # tb_logger.add_image(f'{filename[:-4]}_pred.jpg', gray_to_colormap(pred_scale).transpose((2, 0, 1)), iter)
+ # tb_logger.add_image(f'{filename[:-4]}_gt.jpg', gray_to_colormap(target_scale).transpose((2, 0, 1)), iter)
+ tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter)
+ return max_scale
+
+def get_data_for_log(pred: torch.tensor, target: torch.tensor, rgb: torch.tensor):
+ mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis]
+ std= np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis]
+
+ pred = pred.squeeze().cpu().numpy()
+ target = target.squeeze().cpu().numpy()
+ rgb = rgb.squeeze().cpu().numpy()
+
+ pred[pred<0] = 0
+ target[target<0] = 0
+ #max_scale = max(pred.max(), target.max())
+ max_scale = min(2.0 * target.max(), pred.max())
+ pred[pred > max_scale] = max_scale
+
+ pred_scale = (pred/max_scale * 10000).astype(np.uint16)
+ target_scale = (target/max_scale * 10000).astype(np.uint16)
+ pred_color = gray_to_colormap(pred, max_value=max_scale)
+ target_color = gray_to_colormap(target, max_value=max_scale)
+
+ dilate = True
+    if dilate:
+ k=np.ones((3,3),np.uint8)
+ target_color=cv2.dilate(target_color,k,iterations=1)
+
+ pred_color = cv2.resize(pred_color, (rgb.shape[2], rgb.shape[1]))
+ target_color = cv2.resize(target_color, (rgb.shape[2], rgb.shape[1]))
+
+ rgb = ((rgb * std) + mean).astype(np.uint8)
+ return rgb, pred_scale, target_scale, pred_color, target_color, max_scale
+
+
+def create_html(name2path, save_path='index.html', size=(256, 384)):
+ # table description
+ cols = []
+ for k, v in name2path.items():
+ col_i = Col('img', k, v) # specify image content for column
+ cols.append(col_i)
+ # html table generation
+ imagetable(cols, out_file=save_path, imsize=size)
+
+
+def visual_train_data(gt_depth, rgb, filename, wkdir, replace=False, pred=None):
+ gt_depth = gt_depth.cpu().squeeze().numpy()
+ rgb = rgb.cpu().squeeze().numpy()
+
+ mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis]
+ std= np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis]
+ mask = gt_depth > 0
+
+ rgb = ((rgb * std) + mean).astype(np.uint8).transpose((1, 2, 0))
+ gt_vis = gray_to_colormap(gt_depth)
+ if replace:
+ rgb[mask] = gt_vis[mask]
+
+ if pred is not None:
+ pred_depth = pred.detach().cpu().squeeze().numpy()
+ pred_vis = gray_to_colormap(pred_depth)
+
+ merge = np.concatenate([rgb, gt_vis, pred_vis], axis=0)
+
+ save_path = os.path.join(wkdir, 'test_train', filename)
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ plt.imsave(save_path, merge)
+
+
+def create_dir_for_validate_meta(work_dir, iter_id):
+ curr_folders = glob.glob(work_dir + '/online_val/*0')
+ curr_folders = [i for i in curr_folders if os.path.isdir(i)]
+ if len(curr_folders) > 8:
+ curr_folders.sort()
+        del_folder = curr_folders.pop(0)
+        print(del_folder)
+        if main_process():
+            # only rank==0 does it
+            if os.path.exists(del_folder):
+                shutil.rmtree(del_folder)
+            if os.path.exists(del_folder + '.html'):
+                os.remove(del_folder + '.html')
+
+ save_val_meta_data_dir = os.path.join(work_dir, 'online_val', '%08d'%iter_id)
+ os.makedirs(save_val_meta_data_dir, exist_ok=True)
+ return save_val_meta_data_dir
+
+
+def vis_surface_normal(normal: torch.tensor, mask: torch.tensor=None) -> np.array:
+ """
+ Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255]
+    Args:
+ normal (torch.tensor, [h, w, 3]): surface normal
+ mask (torch.tensor, [h, w]): valid masks
+ """
+ normal = normal.cpu().numpy().squeeze()
+ n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True))
+ n_img_norm = normal / (n_img_L2 + 1e-8)
+ normal_vis = n_img_norm * 127
+ normal_vis += 128
+ normal_vis = normal_vis.astype(np.uint8)
+ if mask is not None:
+ mask = mask.cpu().numpy().squeeze()
+ normal_vis[~mask] = 0
+ return normal_vis
\ No newline at end of file
diff --git a/training/mono/utils/weather_aug_utils.py b/training/mono/utils/weather_aug_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec4aa31a666878782f75fed7868436a7b9c08332
--- /dev/null
+++ b/training/mono/utils/weather_aug_utils.py
@@ -0,0 +1,872 @@
+
+# import glob
+import cv2 as cv2
+import numpy as np
+# import matplotlib.pyplot as plt
+import random
+import math
+
+
+###################### HLS #############################
+
+def hls(image,src='RGB'):
+ verify_image(image)
+ if(is_list(image)):
+ image_HLS=[]
+ image_list=image
+ for img in image_list:
+ eval('image_HLS.append(cv2.cvtColor(img,cv2.COLOR_'+src.upper()+'2HLS))')
+ else:
+ image_HLS = eval('cv2.cvtColor(image,cv2.COLOR_'+src.upper()+'2HLS)')
+ return image_HLS
+
+def hue(image,src='RGB'):
+ verify_image(image)
+ if(is_list(image)):
+ image_Hue=[]
+ image_list=image
+ for img in image_list:
+ image_Hue.append(hls(img,src)[:,:,0])
+ else:
+ image_Hue= hls(image,src)[:,:,0]
+ return image_Hue
+
+def lightness(image,src='RGB'):
+ verify_image(image)
+ if(is_list(image)):
+ image_lightness=[]
+ image_list=image
+ for img in image_list:
+ image_lightness.append(hls(img,src)[:,:,1])
+ else:
+ image_lightness= hls(image,src)[:,:,1]
+ return image_lightness
+
+def saturation(image,src='RGB'):
+ verify_image(image)
+ if(is_list(image)):
+ image_saturation=[]
+ image_list=image
+ for img in image_list:
+ image_saturation.append(hls(img,src)[:,:,2])
+ else:
+ image_saturation= hls(image,src)[:,:,2]
+ return image_saturation
+
+###################### HSV #############################
+
+def hsv(image,src='RGB'):
+ verify_image(image)
+ if(is_list(image)):
+ image_HSV=[]
+ image_list=image
+ for img in image_list:
+ eval('image_HSV.append(cv2.cvtColor(img,cv2.COLOR_'+src.upper()+'2HSV))')
+ else:
+ image_HSV = eval('cv2.cvtColor(image,cv2.COLOR_'+src.upper()+'2HSV)')
+ return image_HSV
+
+def value(image,src='RGB'):
+ verify_image(image)
+ if(is_list(image)):
+ image_value=[]
+ image_list=image
+ for img in image_list:
+ image_value.append(hsv(img,src)[:,:,2])
+ else:
+ image_value= hsv(image,src)[:,:,2]
+ return image_value
+
+###################### BGR #############################
+
+def bgr(image, src='RGB'):
+ verify_image(image)
+ if(is_list(image)):
+ image_BGR=[]
+ image_list=image
+ for img in image_list:
+ eval('image_BGR.append(cv2.cvtColor(img,cv2.COLOR_'+src.upper()+'2BGR))')
+ else:
+ image_BGR= eval('cv2.cvtColor(image,cv2.COLOR_'+src.upper()+'2BGR)')
+ return image_BGR
+
+###################### RGB #############################
+def rgb(image, src='BGR'):
+ verify_image(image)
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ eval('image_RGB.append(cv2.cvtColor(img,cv2.COLOR_'+src.upper()+'2RGB))')
+ else:
+ image_RGB= eval('cv2.cvtColor(image,cv2.COLOR_'+src.upper()+'2RGB)')
+ return image_RGB
+
+def red(image,src='BGR'):
+ verify_image(image)
+ if(is_list(image)):
+ image_red=[]
+ image_list=image
+ for img in image_list:
+ i= eval('cv2.cvtColor(img,cv2.COLOR_'+src.upper()+'2RGB)')
+ image_red.append(i[:,:,0])
+ else:
+ image_red= eval('cv2.cvtColor(image,cv2.COLOR_'+src.upper()+'2RGB)[:,:,0]')
+ return image_red
+
+def green(image,src='BGR'):
+ verify_image(image)
+ if(is_list(image)):
+ image_green=[]
+ image_list=image
+ for img in image_list:
+ i= eval('cv2.cvtColor(img,cv2.COLOR_'+src.upper()+'2RGB)')
+ image_green.append(i[:,:,1])
+ else:
+ image_green= eval('cv2.cvtColor(image,cv2.COLOR_'+src.upper()+'2RGB)[:,:,1]')
+ return image_green
+
+def blue(image,src='BGR'):
+ verify_image(image)
+ if(is_list(image)):
+ image_blue=[]
+ image_list=image
+ for img in image_list:
+ i=eval('cv2.cvtColor(img,cv2.COLOR_'+src.upper()+'2RGB)')
+ image_blue.append(i[:,:,2])
+ else:
+ image_blue= eval('cv2.cvtColor(image,cv2.COLOR_'+src.upper()+'2RGB)[:,:,2]')
+ return image_blue
+
+err_not_np_img= "not a numpy array or a list of numpy arrays"
+err_img_arr_empty="Image array is empty"
+err_row_zero="No. of rows can't be <=0"
+err_column_zero="No. of columns can't be <=0"
+err_invalid_size="Not a valid size tuple (x,y)"
+err_caption_array_count="Caption array length doesn't match the image array length"
+
+def is_numpy_array(x):
+
+ return isinstance(x, np.ndarray)
+def is_tuple(x):
+ return type(x) is tuple
+def is_list(x):
+ return type(x) is list
+def is_numeric(x):
+ return type(x) is int
+def is_numeric_list_or_tuple(x):
+ for i in x:
+ if not is_numeric(i):
+ return False
+ return True
+
+err_brightness_coeff="brightness coeff can only be between 0.0 and 1.0"
+err_darkness_coeff="darkness coeff can only be between 0.0 and 1.0"
+
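+# change_light scales the L (lightness) channel in HLS space by `coeff`:
+# coeff > 1 brightens (values clipped to 255), coeff < 1 darkens (clipped to 0).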
+def change_light(image, coeff):
+ image_HLS = cv2.cvtColor(image,cv2.COLOR_RGB2HLS) ## Conversion to HLS
+ image_HLS = np.array(image_HLS, dtype = np.float64)
+ image_HLS[:,:,1] = image_HLS[:,:,1]*coeff ## scale pixel values up or down for channel 1(Lightness)
+ if(coeff>1):
+ image_HLS[:,:,1][image_HLS[:,:,1]>255] = 255 ##Sets all values above 255 to 255
+ else:
+ image_HLS[:,:,1][image_HLS[:,:,1]<0]=0
+ image_HLS = np.array(image_HLS, dtype = np.uint8)
+ image_RGB = cv2.cvtColor(image_HLS,cv2.COLOR_HLS2RGB) ## Conversion to RGB
+ return image_RGB
+
+def verify_image(image):
+ if is_numpy_array(image):
+ pass
+ elif(is_list(image)):
+ image_list=image
+ for img in image_list:
+ if not is_numpy_array(img):
+ raise Exception(err_not_np_img)
+ else:
+ raise Exception(err_not_np_img)
+
+def brighten(image, brightness_coeff=-1): ##function to brighten the image
+ verify_image(image)
+ if(brightness_coeff!=-1):
+ if(brightness_coeff<0.0 or brightness_coeff>1.0):
+ raise Exception(err_brightness_coeff)
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ if(brightness_coeff==-1):
+ brightness_coeff_t=1+ random.uniform(0,1) ## coeff between 1.0 and 2.0
+ else:
+ brightness_coeff_t=1+ brightness_coeff ## coeff between 1.0 and 2.0
+ image_RGB.append(change_light(img,brightness_coeff_t))
+ else:
+ if(brightness_coeff==-1):
+ brightness_coeff_t=1+ random.uniform(0,1) ## coeff between 1.0 and 2.0
+ else:
+ brightness_coeff_t=1+ brightness_coeff ## coeff between 1.0 and 2.0
+ image_RGB= change_light(image,brightness_coeff_t)
+ return image_RGB
+
+def darken(image, darkness_coeff=-1): ##function to darken the image
+ verify_image(image)
+ if(darkness_coeff!=-1):
+ if(darkness_coeff<0.0 or darkness_coeff>1.0):
+ raise Exception(err_darkness_coeff)
+
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ if(darkness_coeff==-1):
+ darkness_coeff_t=1- random.uniform(0,1)
+ else:
+ darkness_coeff_t=1- darkness_coeff
+ image_RGB.append(change_light(img,darkness_coeff_t))
+ else:
+ if(darkness_coeff==-1):
+ darkness_coeff_t=1- random.uniform(0,1)
+ else:
+ darkness_coeff_t=1- darkness_coeff
+ image_RGB= change_light(image,darkness_coeff_t)
+ return image_RGB
+
+
+def random_brightness(image):
+ verify_image(image)
+
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ random_brightness_coefficient = 2* np.random.uniform(0,1) ## generates value between 0.0 and 2.0
+ image_RGB.append(change_light(img,random_brightness_coefficient))
+ else:
+ random_brightness_coefficient = 2* np.random.uniform(0,1) ## generates value between 0.0 and 2.0
+ image_RGB= change_light(image,random_brightness_coefficient)
+ return image_RGB
+
+err_shadow_count="only 1-10 shadows can be introduced in an image"
+err_invalid_rectangular_roi="Rectangular ROI dimensions are not valid"
+err_shadow_dimension="polygons with dim<3 don't exist and dim>10 take time to plot"
+
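+# Shadows are simulated by sampling `no_of_shadows` random polygons (each with
+# `shadow_dimension` vertices) inside the rectangular ROI and halving the HLS
+# lightness of the pixels they cover.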
+def generate_shadow_coordinates(imshape, no_of_shadows, rectangular_roi, shadow_dimension):
+ vertices_list=[]
+ x1=rectangular_roi[0]
+ y1=rectangular_roi[1]
+ x2=rectangular_roi[2]
+ y2=rectangular_roi[3]
+ for index in range(no_of_shadows):
+ vertex=[]
+ for dimensions in range(shadow_dimension): ## Dimensionality of the shadow polygon
+ vertex.append((random.randint(x1, x2),random.randint(y1, y2)))
+ vertices = np.array([vertex], dtype=np.int32) ## single shadow vertices
+ vertices_list.append(vertices)
+ return vertices_list ## List of shadow vertices
+
+def shadow_process(image,no_of_shadows,x1,y1,x2,y2, shadow_dimension):
+ image_HLS = cv2.cvtColor(image,cv2.COLOR_RGB2HLS) ## Conversion to HLS
+ mask = np.zeros_like(image)
+ imshape = image.shape
+ vertices_list= generate_shadow_coordinates(imshape, no_of_shadows,(x1,y1,x2,y2), shadow_dimension) ## getting list of shadow vertices
+ for vertices in vertices_list:
+ cv2.fillPoly(mask, vertices, 255) ## adding all shadow polygons on empty mask, single 255 denotes only red channel
+ image_HLS[:,:,1][mask[:,:,0]==255] = image_HLS[:,:,1][mask[:,:,0]==255]*0.5 ## if red channel is hot, image's "Lightness" channel's brightness is lowered
+ image_RGB = cv2.cvtColor(image_HLS,cv2.COLOR_HLS2RGB) ## Conversion to RGB
+ return image_RGB
+
+def add_shadow(image,no_of_shadows=1,rectangular_roi=(-1,-1,-1,-1), shadow_dimension=5):## ROI:(top-left x1,y1, bottom-right x2,y2), shadow_dimension=no. of sides of polygon generated
+ verify_image(image)
+ if not(is_numeric(no_of_shadows) and no_of_shadows>=1 and no_of_shadows<=10):
+ raise Exception(err_shadow_count)
+ if not(is_numeric(shadow_dimension) and shadow_dimension>=3 and shadow_dimension<=10):
+ raise Exception(err_shadow_dimension)
+ if is_tuple(rectangular_roi) and is_numeric_list_or_tuple(rectangular_roi) and len(rectangular_roi)==4:
+ x1=rectangular_roi[0]
+ y1=rectangular_roi[1]
+ x2=rectangular_roi[2]
+ y2=rectangular_roi[3]
+ else:
+ raise Exception(err_invalid_rectangular_roi)
+ if rectangular_roi==(-1,-1,-1,-1):
+ x1=0
+
+ if(is_numpy_array(image)):
+ y1=image.shape[0]//2
+ x2=image.shape[1]
+ y2=image.shape[0]
+ else:
+ y1=image[0].shape[0]//2
+ x2=image[0].shape[1]
+ y2=image[0].shape[0]
+
+ elif x1==-1 or y1==-1 or x2==-1 or y2==-1 or x2<=x1 or y2<=y1:
+ raise Exception(err_invalid_rectangular_roi)
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ output=shadow_process(img,no_of_shadows,x1,y1,x2,y2, shadow_dimension)
+ image_RGB.append(output)
+ else:
+ output=shadow_process(image,no_of_shadows,x1,y1,x2,y2, shadow_dimension)
+ image_RGB = output
+
+ return image_RGB
+
+err_snow_coeff="Snow coeff can only be between 0 and 1"
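+# Snow is simulated by brightening (x2.5) every pixel whose HLS lightness falls
+# below a threshold; add_snow maps snow_coeff in [0, 1] to a threshold of roughly
+# 85-212, so larger coefficients whiten more of the image.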
+def snow_process(image,snow_coeff):
+ image_HLS = cv2.cvtColor(image,cv2.COLOR_RGB2HLS) ## Conversion to HLS
+ image_HLS = np.array(image_HLS, dtype = np.float64)
+ brightness_coefficient = 2.5
+ imshape = image.shape
+ snow_point=snow_coeff ## increase this for more snow
+ image_HLS[:,:,1][image_HLS[:,:,1]<snow_point] = image_HLS[:,:,1][image_HLS[:,:,1]<snow_point]*brightness_coefficient ## scale pixel values up for channel 1(Lightness)
+ image_HLS[:,:,1][image_HLS[:,:,1]>255] = 255 ##Sets all values above 255 to 255
+ image_HLS = np.array(image_HLS, dtype = np.uint8)
+ image_RGB = cv2.cvtColor(image_HLS,cv2.COLOR_HLS2RGB) ## Conversion to RGB
+ return image_RGB
+
+def add_snow(image, snow_coeff=-1):
+ verify_image(image)
+ if(snow_coeff!=-1):
+ if(snow_coeff<0.0 or snow_coeff>1.0):
+ raise Exception(err_snow_coeff)
+ else:
+ snow_coeff=random.uniform(0,1)
+ snow_coeff*=255/2
+ snow_coeff+=255/3
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ output= snow_process(img,snow_coeff)
+ image_RGB.append(output)
+ else:
+ output= snow_process(image,snow_coeff)
+ image_RGB=output
+
+ return image_RGB
+
+err_rain_slant="Numeric value between -20 and 20 is allowed"
+err_rain_width="Width value between 1 and 5 is allowed"
+err_rain_length="Length value between 0 and 100 is allowed"
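+# Rain: roughly one drop per 600 px of image area by default (denser for
+# 'torrential', sparser for 'drizzle'); each drop is drawn as a slanted line,
+# after which the frame is blurred and its lightness reduced to mimic overcast light.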
+def generate_random_lines(imshape,slant,drop_length,rain_type):
+ drops=[]
+ area=imshape[0]*imshape[1]
+ no_of_drops=area//600
+
+ if rain_type.lower()=='drizzle':
+ no_of_drops=area//770
+ drop_length=10
+ elif rain_type.lower()=='heavy':
+ drop_length=30
+ elif rain_type.lower()=='torrential':
+ no_of_drops=area//500
+ drop_length=60
+
+ for i in range(no_of_drops): ## If You want heavy rain, try increasing this
+ if slant<0:
+ x= np.random.randint(slant,imshape[1])
+ else:
+ x= np.random.randint(0,imshape[1]-slant)
+ y= np.random.randint(0,imshape[0]-drop_length)
+ drops.append((x,y))
+ return drops,drop_length
+
+def rain_process(image,slant,drop_length,drop_color,drop_width,rain_drops):
+ imshape = image.shape
+ image_t= image.copy()
+ for rain_drop in rain_drops:
+ cv2.line(image_t,(rain_drop[0],rain_drop[1]),(rain_drop[0]+slant,rain_drop[1]+drop_length),drop_color,drop_width)
+ image= cv2.blur(image_t,(7,7)) ## rainy views are blurry
+ brightness_coefficient = 0.7 ## rainy days are usually shady
+ image_HLS = hls(image) ## Conversion to HLS
+ image_HLS[:,:,1] = image_HLS[:,:,1]*brightness_coefficient ## scale pixel values down for channel 1(Lightness)
+ image_RGB= rgb(image_HLS,'hls') ## Conversion to RGB
+ return image_RGB
+
+##rain_type='drizzle','heavy','torrential'
+def add_rain(image,slant=-1,drop_length=20,drop_width=1,drop_color=(200,200,200),rain_type='None'): ## (200,200,200) a shade of gray
+ verify_image(image)
+ slant_extreme=slant
+ if not((is_numeric(slant_extreme) and -20<=slant_extreme<=20) or slant_extreme==-1):
+ raise Exception(err_rain_slant)
+ if not(is_numeric(drop_width) and drop_width>=1 and drop_width<=5):
+ raise Exception(err_rain_width)
+ if not(is_numeric(drop_length) and drop_length>=0 and drop_length<=100):
+ raise Exception(err_rain_length)
+
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ imshape = image[0].shape
+ if slant_extreme==-1:
+ slant= np.random.randint(-10,10) ##generate random slant if no slant value is given
+ rain_drops,drop_length= generate_random_lines(imshape,slant,drop_length,rain_type)
+ for img in image_list:
+ output= rain_process(img,slant,drop_length,drop_color,drop_width,rain_drops)
+ image_RGB.append(output)
+ else:
+ imshape = image.shape
+ if slant_extreme==-1:
+ slant= np.random.randint(-10,10) ##generate random slant if no slant value is given
+ rain_drops,drop_length= generate_random_lines(imshape,slant,drop_length,rain_type)
+ output= rain_process(image,slant,drop_length,drop_color,drop_width,rain_drops)
+ image_RGB=output
+
+ return image_RGB
+
+err_fog_coeff="Fog coeff can only be between 0 and 1"
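+# Fog is built by alpha-blending translucent white circles (alpha = 0.08 * fog_coeff)
+# at points spreading outwards from the image centre, followed by a box blur whose
+# kernel size grows with fog_coeff.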
+def add_blur(image, x,y,hw,fog_coeff):
+ overlay= image.copy()
+ output= image.copy()
+ alpha= 0.08*fog_coeff
+ rad= hw//2
+ point=(x+hw//2, y+hw//2)
+ cv2.circle(overlay,point, int(rad), (255,255,255), -1)
+ cv2.addWeighted(overlay, alpha, output, 1 -alpha ,0, output)
+ return output
+
+def generate_random_blur_coordinates(imshape,hw):
+ blur_points=[]
+ midx= imshape[1]//2-2*hw
+ midy= imshape[0]//2-hw
+ index=1
+ while(midx>-hw or midy>-hw):
+ for i in range(hw//10*index):
+ x= np.random.randint(midx,imshape[1]-midx-hw)
+ y= np.random.randint(midy,imshape[0]-midy-hw)
+ blur_points.append((x,y))
+ midx-=3*hw*imshape[1]//sum(imshape)
+ midy-=3*hw*imshape[0]//sum(imshape)
+ index+=1
+ return blur_points
+
+def add_fog(image, fog_coeff=-1):
+ verify_image(image)
+
+ if(fog_coeff!=-1):
+ if(fog_coeff<0.0 or fog_coeff>1.0):
+ raise Exception(err_fog_coeff)
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ imshape = image[0].shape
+
+ for img in image_list:
+ if fog_coeff==-1:
+ fog_coeff_t=random.uniform(0.3,1)
+ else:
+ fog_coeff_t=fog_coeff
+ hw=int(imshape[1]//3*fog_coeff_t)
+ haze_list= generate_random_blur_coordinates(imshape,hw)
+ for haze_points in haze_list:
+ img= add_blur(img, haze_points[0],haze_points[1], hw,fog_coeff_t) ## blending a translucent white circle at each haze point
+ img = cv2.blur(img ,(hw//10,hw//10))
+ image_RGB.append(img)
+ else:
+ imshape = image.shape
+ if fog_coeff==-1:
+ fog_coeff_t=random.uniform(0.3,1)
+ else:
+ fog_coeff_t=fog_coeff
+ hw=int(imshape[1]//3*fog_coeff_t)
+ haze_list= generate_random_blur_coordinates(imshape,hw)
+ for haze_points in haze_list:
+ image= add_blur(image, haze_points[0],haze_points[1], hw,fog_coeff_t)
+ image = cv2.blur(image ,(hw//10,hw//10))
+ image_RGB = image
+
+ return image_RGB
+
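+# Gravel: random rectangular patches (by default in the lower quarter of the image)
+# are speckled with small blocks of random HLS lightness.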
+def generate_gravel_patch(rectangular_roi):
+ x1=rectangular_roi[0]
+ y1=rectangular_roi[1]
+ x2=rectangular_roi[2]
+ y2=rectangular_roi[3]
+ gravels=[]
+ area= abs((x2-x1)*(y2-y1))
+ for i in range((int)(area//10)):
+ x= np.random.randint(x1,x2)
+ y= np.random.randint(y1,y2)
+ gravels.append((x,y))
+ return gravels
+
+def gravel_process(image,x1,x2,y1,y2,no_of_patches):
+ x=image.shape[1]
+ y=image.shape[0]
+ rectangular_roi_default=[]
+ for i in range(no_of_patches):
+ xx1=random.randint(x1, x2)
+ xx2=random.randint(x1, xx1)
+ yy1=random.randint(y1, y2)
+ yy2=random.randint(y1, yy1)
+ rectangular_roi_default.append((xx2,yy2,min(xx1,xx2+200),min(yy1,yy2+30)))
+ img_hls=hls(image)
+ for roi in rectangular_roi_default:
+ gravels= generate_gravel_patch(roi)
+ for gravel in gravels:
+ x=gravel[0]
+ y=gravel[1]
+ r=random.randint(1, 4)
+ r1=random.randint(0, 255)
+ img_hls[max(y-r,0):min(y+r,y),max(x-r,0):min(x+r,x),1]=r1
+ image_RGB= rgb(img_hls,'hls')
+ return image_RGB
+
+def add_gravel(image,rectangular_roi=(-1,-1,-1,-1), no_of_patches=8):
+ verify_image(image)
+ if is_tuple(rectangular_roi) and is_numeric_list_or_tuple(rectangular_roi) and len(rectangular_roi)==4:
+ x1=rectangular_roi[0]
+ y1=rectangular_roi[1]
+ x2=rectangular_roi[2]
+ y2=rectangular_roi[3]
+ else:
+ raise Exception(err_invalid_rectangular_roi)
+ if rectangular_roi==(-1,-1,-1,-1):
+ if(is_numpy_array(image)):
+ x1=0
+ y1=int(image.shape[0]*3/4)
+ x2=image.shape[1]
+ y2=image.shape[0]
+ else:
+ x1=0
+ y1=int(image[0].shape[0]*3/4)
+ x2=image[0].shape[1]
+ y2=image[0].shape[0]
+ elif x1==-1 or y1==-1 or x2==-1 or y2==-1 or x2<=x1 or y2<=y1:
+ raise Exception(err_invalid_rectangular_roi)
+ color=[0,255]
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ output= gravel_process(img,x1,x2,y1,y2,no_of_patches)
+ image_RGB.append(output)
+ else:
+ output= gravel_process(image,x1,x2,y1,y2,no_of_patches)
+ image_RGB= output
+ return image_RGB
+
+err_flare_circle_count="Numeric value between 0 and 20 is allowed"
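+# Sun flare: flare_source renders the bright source as concentric alpha-blended
+# circles, while add_sun_process scatters smaller random circles along a line
+# through the flare centre at the chosen angle.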
+def flare_source(image, point,radius,src_color):
+ overlay= image.copy()
+ output= image.copy()
+ num_times=radius//10
+ alpha= np.linspace(0.0,1,num= num_times)
+ rad= np.linspace(1,radius, num=num_times)
+ for i in range(num_times):
+ cv2.circle(overlay,point, int(rad[i]), src_color, -1)
+ alp=alpha[num_times-i-1]*alpha[num_times-i-1]*alpha[num_times-i-1]
+ cv2.addWeighted(overlay, alp, output, 1 -alp ,0, output)
+ return output
+
+def add_sun_flare_line(flare_center,angle,imshape):
+ x=[]
+ y=[]
+ i=0
+ for rand_x in range(0,imshape[1],10):
+ rand_y= math.tan(angle)*(rand_x-flare_center[0])+flare_center[1]
+ x.append(rand_x)
+ y.append(2*flare_center[1]-rand_y)
+ return x,y
+
+def add_sun_process(image, no_of_flare_circles,flare_center,src_radius,x,y,src_color):
+ overlay= image.copy()
+ output= image.copy()
+ imshape=image.shape
+ for i in range(no_of_flare_circles):
+ alpha=random.uniform(0.05,0.2)
+ r=random.randint(0, len(x)-1)
+ rad=random.randint(1, imshape[0]//100-2)
+ cv2.circle(overlay,(int(x[r]),int(y[r])), rad*rad*rad, (random.randint(max(src_color[0]-50,0), src_color[0]),random.randint(max(src_color[1]-50,0), src_color[1]),random.randint(max(src_color[2]-50,0), src_color[2])), -1)
+ cv2.addWeighted(overlay, alpha, output, 1 - alpha,0, output)
+ output= flare_source(output,(int(flare_center[0]),int(flare_center[1])),src_radius,src_color)
+ return output
+
+def add_sun_flare(image,flare_center=-1, angle=-1, no_of_flare_circles=8,src_radius=400, src_color=(255,255,255)):
+ verify_image(image)
+ if(angle!=-1):
+ angle=angle%(2*math.pi)
+ if not(no_of_flare_circles>=0 and no_of_flare_circles<=20):
+ raise Exception(err_flare_circle_count)
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ imshape=image_list[0].shape
+ for img in image_list:
+ if(angle==-1):
+ angle_t=random.uniform(0,2*math.pi)
+ if angle_t==math.pi/2:
+ angle_t=0
+ else:
+ angle_t=angle
+ if flare_center==-1:
+ flare_center_t=(random.randint(0,imshape[1]),random.randint(0,imshape[0]//2))
+ else:
+ flare_center_t=flare_center
+ x,y= add_sun_flare_line(flare_center_t,angle_t,imshape)
+ output= add_sun_process(img, no_of_flare_circles,flare_center_t,src_radius,x,y,src_color)
+ image_RGB.append(output)
+ else:
+ imshape=image.shape
+ if(angle==-1):
+ angle_t=random.uniform(0,2*math.pi)
+ if angle_t==math.pi/2:
+ angle_t=0
+ else:
+ angle_t=angle
+ if flare_center==-1:
+ flare_center_t=(random.randint(0,imshape[1]),random.randint(0,imshape[0]//2))
+ else:
+ flare_center_t=flare_center
+ x,y= add_sun_flare_line(flare_center_t,angle_t,imshape)
+ output= add_sun_process(image, no_of_flare_circles,flare_center_t,src_radius,x,y,src_color)
+ image_RGB = output
+ return image_RGB
+
+err_speed_coeff="Speed coeff can only be between 0 and 1"
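+# Speed blur: a horizontal motion-blur kernel (15x15 with a single centre row of
+# ones) is applied repeatedly to progressively narrower bands at the left and right
+# of the frame, so blur accumulates towards the borders; larger speed_coeff values
+# give wider and more strongly blurred bands.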
+def apply_motion_blur(image,count):
+ image_t=image.copy()
+ imshape=image_t.shape
+ size=15
+ kernel_motion_blur = np.zeros((size, size))
+ kernel_motion_blur[int((size-1)/2), :] = np.ones(size)
+ kernel_motion_blur = kernel_motion_blur / size
+ i= imshape[1]*3//4 - 10*count
+ while(i<=imshape[1]):
+ image_t[:,i:,:] = cv2.filter2D(image_t[:,i:,:], -1, kernel_motion_blur)
+ image_t[:,:imshape[1]-i,:] = cv2.filter2D(image_t[:,:imshape[1]-i,:], -1, kernel_motion_blur)
+ i+=imshape[1]//25-count
+ count+=1
+ image_RGB=image_t
+ return image_RGB
+
+def add_speed(image, speed_coeff=-1):
+ verify_image(image)
+ if(speed_coeff !=-1):
+ if(speed_coeff<0.0 or speed_coeff>1.0):
+ raise Exception(err_speed_coeff)
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ if(speed_coeff==-1):
+ count_t=int(15*random.uniform(0,1))
+ else:
+ count_t=int(15*speed_coeff)
+ img=apply_motion_blur(img,count_t)
+ image_RGB.append(img)
+ else:
+ if(speed_coeff==-1):
+ count_t=int(15*random.uniform(0,1))
+ else:
+ count_t=int(15*speed_coeff)
+ image_RGB= apply_motion_blur(image,count_t)
+
+
+ return image_RGB
+
+
+
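+# Autumn: 8x8 pixel blocks whose average hue lies between 20 and 100 (and whose
+# second colour channel averages below 100) are recoloured to a random autumn hue
+# at full saturation.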
+def autumn_process(image):
+ image_t=image.copy()
+ imshape=image_t.shape
+ image_hls= hls(image_t)
+ step=8
+ aut_colors=[1,5,9,11]
+ col= aut_colors[random.randint(0,3)]
+ for i in range(0,imshape[1],step):
+ for j in range(0,imshape[0],step):
+ avg=np.average(image_hls[j:j+step,i:i+step,0])
+# print(avg)
+ if(avg >20 and avg< 100 and np.average(image[j:j+step,i:i+step,1])<100):
+ image_hls[j:j+step,i:i+step,0]= col
+ image_hls[j:j+step,i:i+step,2]=255
+ return rgb(image_hls,'hls')
+
+
+def add_autumn(image):
+ verify_image(image)
+
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+
+ img=autumn_process(img)
+ image_RGB.append(img)
+ else:
+ image=autumn_process(image)
+ image_RGB= image
+
+ return image_RGB
+
+def fliph(image): ##function to flip the image on horizontal axis
+ verify_image(image)
+
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ image_RGB.append(cv2.flip(img,0))
+ else:
+ image_RGB= cv2.flip(image,0)
+ return image_RGB
+
+def flipv(image): ##function to flip the image on vertical axis
+ verify_image(image)
+
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ image_RGB.append(cv2.flip(img,1))
+ else:
+ image_RGB= cv2.flip(image,1)
+ return image_RGB
+
+def random_flip(image): ##function to randomly flip the image about either axis
+ verify_image(image)
+
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ p= random.uniform(0,1)
+ if(p>0.5):
+ image_RGB.append(cv2.flip(img,0))
+ else:
+ image_RGB.append(cv2.flip(img,1))
+ else:
+ p= random.uniform(0,1)
+ if(p>0.5):
+ image_RGB=cv2.flip(image,0)
+ else:
+ image_RGB=cv2.flip(image,1)
+ return image_RGB
+
+def manhole_process(image,center,height,width,src_color=(0,0,0)):
+ overlay= image.copy()
+ output= image.copy()
+# cv2.ellipse(overlay, center =center,box=None,color =src_color)
+ cv2.ellipse(overlay, center, (width,height), 0, 0, 360, src_color, -1)
+# cv2.circle(overlay, center, radius, src_color, -1)
+ alp=1
+ cv2.addWeighted(overlay, alp, output, 1 -alp ,0, output)
+ return output
+
+err_invalid_center_manhole="center should be in the format (x,y)"
+err_invalid_height_width_manhole="height and width should be positive integers."
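+# Manhole: a filled ellipse is blended onto the image at `center` (x, y); when no
+# colour is given, 'closed' manholes use dark grey and 'open' ones use black.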
+def add_manhole(image,center=-1,color=(120,120,120),height=1,width=1, type='closed'): ##function to add a manhole (filled ellipse) to the image
+ verify_image(image)
+
+ if(center!=-1):
+ if not(is_tuple(center) and is_numeric_list_or_tuple(center) and len(center)==2):
+ raise Exception(err_invalid_center_manhole)
+ if not (is_numeric(height) and is_numeric(width) and height>0 and width>0):
+ raise Exception(err_invalid_height_width_manhole)
+ if color==(120,120,120):
+ if type=='closed':
+ color=(67,70,75)
+ elif type=='open':
+ color=(0,0,0)
+
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ height_t=height
+ width_t=width
+ center_t=center
+ if height==1:
+ height_t=img.shape[0]//25
+ if width==1:
+ width_t=int(img.shape[0]*3//25)
+ if center==-1:
+ center_t= (img.shape[0]-100, img.shape[1]//2)
+ image_RGB.append(manhole_process(img,center_t,height_t,width_t,color))
+ else:
+ height_t=height
+ width_t=width
+ center_t=center
+ if height==1:
+ height_t=image.shape[0]//25
+ if width==1:
+ width_t=int(image.shape[0]*3//25)
+ if center==-1:
+ center_t= (image.shape[0]-100, image.shape[1]//2)
+ image_RGB= manhole_process(image,center_t,height_t,width_t,color)
+ return image_RGB
+
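+# Exposure correction: convert to YUV, damp luma values above 150, apply CLAHE and
+# global histogram equalisation to the luma channel, then lightly denoise the result.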
+def exposure_process(image):
+ image= np.copy(image)
+ img_yuv = cv2.cvtColor(image, cv2.COLOR_BGR2YUV)
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(4,4))
+ ones= np.ones(img_yuv[:,:,0].shape)
+ ones[img_yuv[:,:,0]>150]= 0.85
+ img_yuv[:,:,0]= img_yuv[:,:,0]*ones
+
+ img_yuv[:,:,0] = clahe.apply(img_yuv[:,:,0])
+ img_yuv[:,:,0] = cv2.equalizeHist(img_yuv[:,:,0])
+ img_yuv[:,:,0] = clahe.apply(img_yuv[:,:,0])
+
+ image_res = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2BGR)
+ image_res= cv2.fastNlMeansDenoisingColored(image_res,None,3,3,7,21)
+ return image_res
+
+def correct_exposure(image):
+ verify_image(image)
+ if(is_list(image)):
+ image_RGB=[]
+ image_list=image
+ for img in image_list:
+ image_RGB.append(exposure_process(img))
+ else:
+ image_RGB= exposure_process(image)
+ return image_RGB
+
+err_aug_type='wrong augmentation function is defined'
+err_aug_list_type='aug_types should be a list of string function names'
+err_aug_volume='volume type can only be "same" or "expand"'
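+# augment_random dispatches to the augmentations listed in aug_types by name. With
+# volume='expand' every listed augmentation is applied to the input and all results
+# are returned; with volume='same' each input image receives one randomly chosen
+# augmentation. Illustrative usage (assuming `img` is an RGB numpy array):
+#   foggy, rainy = augment_random(img, aug_types=["add_fog", "add_rain"], volume='expand')
+#   mixed = augment_random([img, img], aug_types=["add_snow", "add_shadow"], volume='same')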
+def augment_random(image, aug_types="", volume='expand' ):
+
+ aug_types_all=["random_brightness","add_shadow","add_snow","add_rain","add_fog","add_gravel","add_sun_flare","add_speed","add_autumn","random_flip","add_manhole"]
+ if aug_types=="":
+ aug_types=aug_types_all
+ output=[]
+ if not(is_list(aug_types)):
+ raise Exception(err_aug_list_type)
+
+ if volume=='expand':
+ for aug_type in aug_types:
+
+ if not(aug_type in aug_types_all):
+ raise Exception(err_aug_type)
+ command=aug_type+'(image)'
+ result=eval(command)
+ if(is_list(result)):
+ output+=result
+ else:
+ output.append(result)
+ elif volume=='same':
+ verify_image(image)
+ for aug_type in aug_types:
+ if not(aug_type in aug_types_all):
+ raise Exception(err_aug_type)
+ if(is_list(image)):
+ image_list=image
+ for img in image_list:
+ selected_aug=aug_types[random.randint(0,len(aug_types)-1)]
+ command=selected_aug+'(img)'
+ output.append(eval(command))
+ else:
+ selected_aug=aug_types[random.randint(0,len(aug_types)-1)]
+ command=selected_aug+'(image)'
+ output=eval(command)
+
+ else:
+ raise Exception(err_aug_volume)
+
+ return output
\ No newline at end of file