diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..13357729bfdcd22f326c283ffa3bfca02346c638 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +__pycache__/ +interal/__pycache__/ +tests/__pycache__/ +.DS_Store +.vscode/ +.idea/ +__MACOSX/ +exp/ +data/ +assets/ +test.py +test2.py +*.mp4 +*.ply +scripts/train_360_debug.sh \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d645695673349e3947e8e5ae42332d0ac3164cd7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
diff --git a/README.md b/README.md
index 154df8298fab5ecf322016157858e08cd1bccbe1..9df9e1e79772f5e280aa00cd5d91b01baf1f6fdf 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,252 @@
----
-license: apache-2.0
----
+# ZipNeRF
+
+An unofficial PyTorch implementation of
+"Zip-NeRF: Anti-Aliased Grid-Based Neural Radiance Fields"
+[https://arxiv.org/abs/2304.06706](https://arxiv.org/abs/2304.06706).
+This work is based on [multinerf](https://github.com/google-research/multinerf), so the Ref-NeRF, RawNeRF, and MipNeRF 360 features are also available.
+
+## News
+- (6.22) Add mesh extraction via TSDF; add [gradient scaling](https://gradient-scaling.github.io/) to reduce near-plane floaters.
+- (5.26) Implement the latest version of ZipNeRF [https://arxiv.org/abs/2304.06706](https://arxiv.org/abs/2304.06706).
+- (5.22) Add mesh extraction; add a logging and checkpointing system.
+
+## Results
+New results (5.27):
+
+360_v2:
+
+https://github.com/SuLvXiangXin/zipnerf-pytorch/assets/83005605/2b276e48-2dc4-4508-8441-e90ec963f7d9
+
+360_v2 with GLO (fewer floaters, but slightly worse metrics):
+
+https://github.com/SuLvXiangXin/zipnerf-pytorch/assets/83005605/bddb5610-2a4f-4981-8e17-71326a24d291
+
+Mesh results (5.27):
+
+![mesh](https://github.com/SuLvXiangXin/zipnerf-pytorch/assets/83005605/35866fa7-fe6a-44fe-9590-05d594bdb8cd)
+
+MipNeRF 360 (PSNR):
+
+|           | bicycle | garden | stump | room  | counter | kitchen | bonsai |
+|:---------:|:-------:|:------:|:-----:|:-----:|:-------:|:-------:|:------:|
+| Paper     | 25.80   | 28.20  | 27.55 | 32.65 | 29.38   | 32.50   | 34.46  |
+| This repo | 25.44   | 27.98  | 26.75 | 32.13 | 29.10   | 32.63   | 34.20  |
+
+Blender (PSNR):
+
+|           | chair | drums | ficus | hotdog | lego  | materials | mic   | ship  |
+|:---------:|:-----:|:-----:|:-----:|:------:|:-----:|:---------:|:-----:|:-----:|
+| Paper     | 34.84 | 25.84 | 33.90 | 37.14  | 34.84 | 31.66     | 35.15 | 31.38 |
+| This repo | 35.26 | 25.51 | 32.66 | 36.56  | 35.04 | 29.43     | 34.93 | 31.38 |
+
+For the MipNeRF 360 dataset, the model is trained with a downsample factor of 4 for outdoor scenes and 2 for indoor scenes (same as in the paper).
+Training is about 1.5x slower than reported in the paper (1.5 hours on 8 A6000 GPUs).
+
+The hash decay loss seems to have little effect, as many floaters remain in the final results of both experiments (especially on Blender).
+
+## Install
+
+```
+# Clone the repo.
+git clone https://github.com/SuLvXiangXin/zipnerf-pytorch.git
+cd zipnerf-pytorch
+
+# Make a conda environment.
+conda create --name zipnerf python=3.9
+conda activate zipnerf
+
+# Install requirements.
+pip install -r requirements.txt
+
+# Install other extensions.
+pip install ./gridencoder
+
+# Install nvdiffrast (optional, for textured mesh).
+git clone https://github.com/NVlabs/nvdiffrast
+pip install ./nvdiffrast
+
+# Install a torch_scatter build matching your CUDA version;
+# see https://github.com/rusty1s/pytorch_scatter for more details.
+CUDA=cu117
+pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html
+```
+
+## Dataset
+[mipnerf360](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip)
+
+[refnerf](https://storage.googleapis.com/gresearch/refraw360/ref.zip)
+
+[nerf_synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
+
+[nerf_llff_data](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
+
+```
+mkdir data
+cd data
+
+# e.g. mipnerf360 data
+wget http://storage.googleapis.com/gresearch/refraw360/360_v2.zip
+unzip 360_v2.zip
+```
+
+## Train
+```
+# Configure your training setup (DDP? fp16? ...);
+# see https://huggingface.co/docs/accelerate/index for details.
+accelerate config
+
+# Where your data is.
+DATA_DIR=data/360_v2/bicycle
+EXP_NAME=360_v2/bicycle
+
+# The experiment will be conducted under the "exp/${EXP_NAME}" folder.
+# "--gin_configs=configs/360.gin" acts as the default config,
+# and you can add specific overrides using --gin_bindings="..."
+accelerate launch train.py \
+  --gin_configs=configs/360.gin \
+  --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+  --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+  --gin_bindings="Config.factor = 4"
+
+# Or run without accelerate (no DDP).
+CUDA_VISIBLE_DEVICES=0 python train.py \
+  --gin_configs=configs/360.gin \
+  --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+  --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+  --gin_bindings="Config.factor = 4"
+
+# Alternatively, use an example training script.
+bash scripts/train_360.sh
+
+# Blender dataset.
+bash scripts/train_blender.sh
+
+# Metrics, rendered images, etc. can be viewed through tensorboard.
+tensorboard --logdir "exp/${EXP_NAME}"
+```
+
+### Render
+Rendering results can be found in the directory `exp/${EXP_NAME}/render`.
+```
+accelerate launch render.py \
+  --gin_configs=configs/360.gin \
+  --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+  --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+  --gin_bindings="Config.render_path = True" \
+  --gin_bindings="Config.render_path_frames = 480" \
+  --gin_bindings="Config.render_video_fps = 60" \
+  --gin_bindings="Config.factor = 4"
+
+# Alternatively, use an example rendering script.
+bash scripts/render_360.sh
+```
+
+## Evaluate
+Evaluation results can be found in the directory `exp/${EXP_NAME}/test_preds`.
+```
+# Use the same exp_name as in training.
+accelerate launch eval.py \
+  --gin_configs=configs/360.gin \
+  --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+  --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+  --gin_bindings="Config.factor = 4"
+
+# Alternatively, use an example evaluation script.
+bash scripts/eval_360.sh
+```
+
+## Extract mesh
+Mesh results can be found in the directory `exp/${EXP_NAME}/mesh`.
+```
+# More configuration options can be found in internal/configs.py.
+accelerate launch extract.py \
+  --gin_configs=configs/360.gin \
+  --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+  --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+  --gin_bindings="Config.factor = 4"
+# --gin_bindings="Config.mesh_radius = 1"            # (optional) smaller values give more detail, e.g. 0.2 for the bicycle scene
+# --gin_bindings="Config.isosurface_threshold = 20"  # (optional) empirical value
+# --gin_bindings="Config.mesh_voxels = 134217728"    # (optional) number of voxels used to extract the mesh; 134217728 equals 512**3, and smaller values may avoid an OutOfMemoryError
+# --gin_bindings="Config.vertex_color = True"        # (optional) save the mesh with vertex colors instead of a texture atlas (the atlas is much slower to generate but captures more detail)
+# --gin_bindings="Config.vertex_projection = True"   # (optional) use projection for the vertex colors
+
+# Or extract the mesh using the TSDF method.
+accelerate launch extract.py \
+  --gin_configs=configs/360.gin \
+  --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+  --gin_bindings="Config.exp_name = '${EXP_NAME}'" \
+  --gin_bindings="Config.factor = 4"
+
+# Alternatively, use an example script.
+bash scripts/extract_360.sh
+```
+
+## OutOfMemory
+You can decrease the total batch size by adding e.g. `--gin_bindings="Config.batch_size = 8192"`,
+decrease the test chunk size by adding e.g. `--gin_bindings="Config.render_chunk_size = 8192"`,
+or use more GPUs by reconfiguring `accelerate config`.
+
+## Preparing custom data
+More details can be found at https://github.com/google-research/multinerf.
+```
+DATA_DIR=my_dataset_dir
+bash scripts/local_colmap_and_resize.sh ${DATA_DIR}
+```
+
+## TODO
+- [x] Add MultiScale training and testing
+
+## Citation
+```
+@misc{barron2023zipnerf,
+      title={Zip-NeRF: Anti-Aliased Grid-Based Neural Radiance Fields},
+      author={Jonathan T. Barron and Ben Mildenhall and Dor Verbin and Pratul P. Srinivasan and Peter Hedman},
+      year={2023},
+      eprint={2304.06706},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV}
+}
+
+@misc{multinerf2022,
+      title={{MultiNeRF}: {A} {Code} {Release} for {Mip-NeRF} 360, {Ref-NeRF}, and {RawNeRF}},
+      author={Ben Mildenhall and Dor Verbin and Pratul P. Srinivasan and Peter Hedman and Ricardo Martin-Brualla and Jonathan T. Barron},
+      year={2022},
+      url={https://github.com/google-research/multinerf},
+}
+
+@misc{accelerate,
+      title={Accelerate: Training and inference at scale made simple, efficient and adaptable.},
+      author={Sylvain Gugger and Lysandre Debut and Thomas Wolf and Philipp Schmid and Zachary Mueller and Sourab Mangrulkar},
+      howpublished={\url{https://github.com/huggingface/accelerate}},
+      year={2022}
+}
+
+@misc{torch-ngp,
+      title={Torch-ngp: a PyTorch implementation of instant-ngp},
+      author={Jiaxiang Tang},
+      year={2022},
+      note={https://github.com/ashawkey/torch-ngp}
+}
+```
+
+## Acknowledgements
+This work is based on my other repo https://github.com/SuLvXiangXin/multinerf-pytorch,
+which is essentially a PyTorch translation of [multinerf](https://github.com/google-research/multinerf).
+
+- Thanks to [multinerf](https://github.com/google-research/multinerf) for the amazing multinerf (MipNeRF 360, Ref-NeRF, RawNeRF) implementation
+- Thanks to [accelerate](https://github.com/huggingface/accelerate) for distributed training
+- Thanks to [torch-ngp](https://github.com/ashawkey/torch-ngp) for the super useful hash encoder
+- Thanks to [Yurui Chen](https://github.com/519401113) for discussing the details of the paper.
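+
+## Custom config (optional)
+Instead of passing many `--gin_bindings` flags on the command line, per-run overrides can be
+collected into a small gin file that `include`s one of the files under `configs/`, the same
+mechanism used by `configs/multi360.gin`. The sketch below only uses `Config.*` fields that
+already appear in the commands above; the filename `configs/my_360.gin` is just a hypothetical
+example.
+```
+# configs/my_360.gin (hypothetical example)
+include 'configs/360.gin'
+
+Config.data_dir = 'data/360_v2/bicycle'
+Config.exp_name = '360_v2/bicycle'
+Config.factor = 4
+Config.batch_size = 8192          # lower this if you run out of memory
+Config.render_chunk_size = 8192
+```
+Then point the launcher at that single file:
+```
+accelerate launch train.py --gin_configs=configs/my_360.gin
+```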
diff --git a/configs/360.gin b/configs/360.gin new file mode 100644 index 0000000000000000000000000000000000000000..d39fe8bab2c3e2add7fc46fb152dbdfc4850991e --- /dev/null +++ b/configs/360.gin @@ -0,0 +1,15 @@ +Config.exp_name = 'test' +Config.dataset_loader = 'llff' +Config.near = 0.2 +Config.far = 1e6 +Config.factor = 4 + +Model.raydist_fn = 'power_transformation' +Model.opaque_background = True + +PropMLP.disable_density_normals = True +PropMLP.disable_rgb = True +PropMLP.grid_level_dim = 1 + +NerfMLP.disable_density_normals = True + diff --git a/configs/360_glo.gin b/configs/360_glo.gin new file mode 100644 index 0000000000000000000000000000000000000000..42a24362183c475d225035a8d0ab63204a3ee48f --- /dev/null +++ b/configs/360_glo.gin @@ -0,0 +1,15 @@ +Config.dataset_loader = 'llff' +Config.near = 0.2 +Config.far = 1e6 +Config.factor = 4 + +Model.raydist_fn = 'power_transformation' +Model.num_glo_features = 128 +Model.opaque_background = True + +PropMLP.disable_density_normals = True +PropMLP.disable_rgb = True +PropMLP.grid_level_dim = 1 + + +NerfMLP.disable_density_normals = True diff --git a/configs/blender.gin b/configs/blender.gin new file mode 100644 index 0000000000000000000000000000000000000000..20f74c9d27afb78a8c5c636c0572957226eb645a --- /dev/null +++ b/configs/blender.gin @@ -0,0 +1,15 @@ +Config.exp_name = 'test' +Config.dataset_loader = 'blender' +Config.near = 2 +Config.far = 6 +Config.factor = 0 +Config.hash_decay_mults = 10 + +Model.raydist_fn = None + +PropMLP.disable_density_normals = True +PropMLP.disable_rgb = True +PropMLP.grid_level_dim = 1 + +NerfMLP.disable_density_normals = True + diff --git a/configs/blender_refnerf.gin b/configs/blender_refnerf.gin new file mode 100644 index 0000000000000000000000000000000000000000..42da4df7dc604a620cbb2b9ca3cc9b4db330eea0 --- /dev/null +++ b/configs/blender_refnerf.gin @@ -0,0 +1,41 @@ +Config.dataset_loader = 'blender' +Config.batching = 'single_image' +Config.near = 2 +Config.far = 6 + +Config.eval_render_interval = 5 +Config.compute_normal_metrics = True +Config.data_loss_type = 'mse' +Config.distortion_loss_mult = 0.0 +Config.orientation_loss_mult = 0.1 +Config.orientation_loss_target = 'normals_pred' +Config.predicted_normal_loss_mult = 3e-4 +Config.orientation_coarse_loss_mult = 0.01 +Config.predicted_normal_coarse_loss_mult = 3e-5 +Config.interlevel_loss_mult = 0.0 +Config.data_coarse_loss_mult = 0.1 +Config.adam_eps = 1e-8 + +Model.num_levels = 2 +Model.single_mlp = True +Model.num_prop_samples = 128 # This needs to be set despite single_mlp = True. +Model.num_nerf_samples = 128 +Model.anneal_slope = 0. +Model.dilation_multiplier = 0. +Model.dilation_bias = 0. +Model.single_jitter = False +Model.resample_padding = 0.01 +Model.distinct_prop = False + +NerfMLP.disable_density_normals = False +NerfMLP.enable_pred_normals = True +NerfMLP.use_directional_enc = True +NerfMLP.use_reflections = True +NerfMLP.deg_view = 5 +NerfMLP.enable_pred_roughness = True +NerfMLP.use_diffuse_color = True +NerfMLP.use_specular_tint = True +NerfMLP.use_n_dot_v = True +NerfMLP.bottleneck_width = 128 +NerfMLP.density_bias = 0.5 +NerfMLP.max_deg_point = 16 diff --git a/configs/llff_256.gin b/configs/llff_256.gin new file mode 100644 index 0000000000000000000000000000000000000000..eeeba4fc80c029a298d96d6b71d743d3ccc45840 --- /dev/null +++ b/configs/llff_256.gin @@ -0,0 +1,19 @@ +Config.dataset_loader = 'llff' +Config.near = 0. +Config.far = 1. 
+Config.factor = 4 +Config.forward_facing = True +Config.adam_eps = 1e-8 + +Model.opaque_background = True +Model.num_levels = 2 +Model.num_prop_samples = 128 +Model.num_nerf_samples = 32 + +PropMLP.disable_density_normals = True +PropMLP.disable_rgb = True + +NerfMLP.disable_density_normals = True + +NerfMLP.max_deg_point = 16 +PropMLP.max_deg_point = 16 diff --git a/configs/llff_512.gin b/configs/llff_512.gin new file mode 100644 index 0000000000000000000000000000000000000000..eeeba4fc80c029a298d96d6b71d743d3ccc45840 --- /dev/null +++ b/configs/llff_512.gin @@ -0,0 +1,19 @@ +Config.dataset_loader = 'llff' +Config.near = 0. +Config.far = 1. +Config.factor = 4 +Config.forward_facing = True +Config.adam_eps = 1e-8 + +Model.opaque_background = True +Model.num_levels = 2 +Model.num_prop_samples = 128 +Model.num_nerf_samples = 32 + +PropMLP.disable_density_normals = True +PropMLP.disable_rgb = True + +NerfMLP.disable_density_normals = True + +NerfMLP.max_deg_point = 16 +PropMLP.max_deg_point = 16 diff --git a/configs/llff_raw.gin b/configs/llff_raw.gin new file mode 100644 index 0000000000000000000000000000000000000000..343f226b90af80ff45c0c1495dc0065f639e5842 --- /dev/null +++ b/configs/llff_raw.gin @@ -0,0 +1,73 @@ +# General LLFF settings + +Config.dataset_loader = 'llff' +Config.near = 0. +Config.far = 1. +Config.factor = 4 +Config.forward_facing = True + +PropMLP.disable_density_normals = True # Turn this off if using orientation loss. +PropMLP.disable_rgb = True + +NerfMLP.disable_density_normals = True # Turn this off if using orientation loss. + +NerfMLP.max_deg_point = 16 +PropMLP.max_deg_point = 16 + +Config.train_render_every = 5000 + + +########################## RawNeRF specific settings ########################## + +Config.rawnerf_mode = True +Config.data_loss_type = 'rawnerf' +Config.apply_bayer_mask = True +Model.learned_exposure_scaling = True + +Model.num_levels = 2 +Model.num_prop_samples = 128 # Using extra samples for now because of noise instability. +Model.num_nerf_samples = 128 +Model.opaque_background = True +Model.distinct_prop = False + +# RGB activation we use for linear color outputs is exp(x - 5). +NerfMLP.rgb_padding = 0. +NerfMLP.rgb_activation = @math.safe_exp +NerfMLP.rgb_bias = -5. +PropMLP.rgb_padding = 0. +PropMLP.rgb_activation = @math.safe_exp +PropMLP.rgb_bias = -5. + +## Experimenting with the various regularizers and losses: +Config.interlevel_loss_mult = .0 # Turning off interlevel for now (default = 1.). +Config.distortion_loss_mult = .01 # Distortion loss helps with floaters (default = .01). +Config.orientation_loss_mult = 0. # Orientation loss also not great (try .01). +Config.data_coarse_loss_mult = 0.1 # Setting this to match old MipNeRF. + +## Density noise used in original NeRF: +NerfMLP.density_noise = 1. +PropMLP.density_noise = 1. + +## Use a single MLP for all rounds of sampling: +Model.single_mlp = True + +## Some algorithmic settings to match the paper: +Model.anneal_slope = 0. +Model.dilation_multiplier = 0. +Model.dilation_bias = 0. 
+Model.single_jitter = False +NerfMLP.weight_init = 'glorot_uniform' +PropMLP.weight_init = 'glorot_uniform' + +## Training hyperparameters used in the paper: +Config.batch_size = 16384 +Config.render_chunk_size = 16384 +Config.lr_init = 1e-3 +Config.lr_final = 1e-5 +Config.max_steps = 500000 +Config.checkpoint_every = 25000 +Config.lr_delay_steps = 2500 +Config.lr_delay_mult = 0.01 +Config.grad_max_norm = 0.1 +Config.grad_max_val = 0.1 +Config.adam_eps = 1e-8 diff --git a/configs/multi360.gin b/configs/multi360.gin new file mode 100644 index 0000000000000000000000000000000000000000..e9bef1a30c50fdc253d4aef2b4aebc897be6803c --- /dev/null +++ b/configs/multi360.gin @@ -0,0 +1,5 @@ +include 'configs/360.gin' +Config.multiscale = True +Config.multiscale_levels = 4 + + diff --git a/eval.py b/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..167a5509876e2647fd02bea701e08b61c2dd856a --- /dev/null +++ b/eval.py @@ -0,0 +1,307 @@ +import logging +import os +import sys +import time +import accelerate +from absl import app +import gin +from internal import configs +from internal import datasets +from internal import image +from internal import models +from internal import raw_utils +from internal import ref_utils +from internal import train_utils +from internal import checkpoints +from internal import utils +from internal import vis +import numpy as np +import torch +import tensorboardX +from torch.utils._pytree import tree_map + +configs.define_common_flags() + + +def summarize_results(folder, scene_names, num_buckets): + metric_names = ['psnrs', 'ssims', 'lpips'] + num_iters = 1000000 + precisions = [3, 4, 4, 4] + + results = [] + for scene_name in scene_names: + test_preds_folder = os.path.join(folder, scene_name, 'test_preds') + values = [] + for metric_name in metric_names: + filename = os.path.join(folder, scene_name, 'test_preds', f'{metric_name}_{num_iters}.txt') + with utils.open_file(filename) as f: + v = np.array([float(s) for s in f.readline().split(' ')]) + values.append(np.mean(np.reshape(v, [-1, num_buckets]), 0)) + results.append(np.concatenate(values)) + avg_results = np.mean(np.array(results), 0) + + psnr, ssim, lpips = np.mean(np.reshape(avg_results, [-1, num_buckets]), 1) + + mse = np.exp(-0.1 * np.log(10.) 
* psnr) + dssim = np.sqrt(1 - ssim) + avg_avg = np.exp(np.mean(np.log(np.array([mse, dssim, lpips])))) + + s = [] + for i, v in enumerate(np.reshape(avg_results, [-1, num_buckets])): + s.append(' '.join([f'{s:0.{precisions[i]}f}' for s in v])) + s.append(f'{avg_avg:0.{precisions[-1]}f}') + return ' | '.join(s) + + +def main(unused_argv): + config = configs.load_config() + config.exp_path = os.path.join('exp', config.exp_name) + config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints') + config.render_dir = os.path.join(config.exp_path, 'render') + + accelerator = accelerate.Accelerator() + + # setup logger + logging.basicConfig( + format="%(asctime)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + handlers=[logging.StreamHandler(sys.stdout), + logging.FileHandler(os.path.join(config.exp_path, 'log_eval.txt'))], + level=logging.INFO, + ) + sys.excepthook = utils.handle_exception + logger = accelerate.logging.get_logger(__name__) + logger.info(config) + logger.info(accelerator.state, main_process_only=False) + + config.world_size = accelerator.num_processes + config.global_rank = accelerator.process_index + accelerate.utils.set_seed(config.seed, device_specific=True) + model = models.Model(config=config) + model.eval() + model.to(accelerator.device) + + dataset = datasets.load_dataset('test', config.data_dir, config) + dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)), + shuffle=False, + batch_size=1, + collate_fn=dataset.collate_fn, + ) + tb_process_fn = lambda x: x.transpose(2, 0, 1) if len(x.shape) == 3 else x[None] + if config.rawnerf_mode: + postprocess_fn = dataset.metadata['postprocess_fn'] + else: + postprocess_fn = lambda z: z + + if config.eval_raw_affine_cc: + cc_fun = raw_utils.match_images_affine + else: + cc_fun = image.color_correct + + model = accelerator.prepare(model) + + metric_harness = image.MetricHarness() + + last_step = 0 + out_dir = os.path.join(config.exp_path, + 'path_renders' if config.render_path else 'test_preds') + path_fn = lambda x: os.path.join(out_dir, x) + + if not config.eval_only_once: + summary_writer = tensorboardX.SummaryWriter( + os.path.join(config.exp_path, 'eval')) + while True: + step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger) + if step <= last_step: + logger.info(f'Checkpoint step {step} <= last step {last_step}, sleeping.') + time.sleep(10) + continue + logger.info(f'Evaluating checkpoint at step {step}.') + if config.eval_save_output and (not utils.isdir(out_dir)): + utils.makedirs(out_dir) + + num_eval = min(dataset.size, config.eval_dataset_limit) + perm = np.random.permutation(num_eval) + showcase_indices = np.sort(perm[:config.num_showcase_images]) + metrics = [] + metrics_cc = [] + showcases = [] + render_times = [] + for idx, batch in enumerate(dataloader): + batch = accelerate.utils.send_to_device(batch, accelerator.device) + eval_start_time = time.time() + if idx >= num_eval: + logger.info(f'Skipping image {idx + 1}/{dataset.size}') + continue + logger.info(f'Evaluating image {idx + 1}/{dataset.size}') + rendering = models.render_image(model, accelerator, + batch, False, 1, config) + + if not accelerator.is_main_process: # Only record via host 0. 
+ continue + + render_times.append((time.time() - eval_start_time)) + logger.info(f'Rendered in {render_times[-1]:0.3f}s') + + cc_start_time = time.time() + rendering['rgb_cc'] = cc_fun(rendering['rgb'], batch['rgb']) + + rendering = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, rendering) + batch = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, batch) + + gt_rgb = batch['rgb'] + logger.info(f'Color corrected in {(time.time() - cc_start_time):0.3f}s') + + if not config.eval_only_once and idx in showcase_indices: + showcase_idx = idx if config.deterministic_showcase else len(showcases) + showcases.append((showcase_idx, rendering, batch)) + if not config.render_path: + rgb = postprocess_fn(rendering['rgb']) + rgb_cc = postprocess_fn(rendering['rgb_cc']) + rgb_gt = postprocess_fn(gt_rgb) + + if config.eval_quantize_metrics: + # Ensures that the images written to disk reproduce the metrics. + rgb = np.round(rgb * 255) / 255 + rgb_cc = np.round(rgb_cc * 255) / 255 + + if config.eval_crop_borders > 0: + crop_fn = lambda x, c=config.eval_crop_borders: x[c:-c, c:-c] + rgb = crop_fn(rgb) + rgb_cc = crop_fn(rgb_cc) + rgb_gt = crop_fn(rgb_gt) + + metric = metric_harness(rgb, rgb_gt) + metric_cc = metric_harness(rgb_cc, rgb_gt) + + if config.compute_disp_metrics: + for tag in ['mean', 'median']: + key = f'distance_{tag}' + if key in rendering: + disparity = 1 / (1 + rendering[key]) + metric[f'disparity_{tag}_mse'] = float( + ((disparity - batch['disps']) ** 2).mean()) + + if config.compute_normal_metrics: + weights = rendering['acc'] * batch['alphas'] + normalized_normals_gt = ref_utils.l2_normalize_np(batch['normals']) + for key, val in rendering.items(): + if key.startswith('normals') and val is not None: + normalized_normals = ref_utils.l2_normalize_np(val) + metric[key + '_mae'] = ref_utils.compute_weighted_mae_np( + weights, normalized_normals, normalized_normals_gt) + + for m, v in metric.items(): + logger.info(f'{m:30s} = {v:.4f}') + + metrics.append(metric) + metrics_cc.append(metric_cc) + + if config.eval_save_output and (config.eval_render_interval > 0): + if (idx % config.eval_render_interval) == 0: + utils.save_img_u8(postprocess_fn(rendering['rgb']), + path_fn(f'color_{idx:03d}.png')) + utils.save_img_u8(postprocess_fn(rendering['rgb_cc']), + path_fn(f'color_cc_{idx:03d}.png')) + + for key in ['distance_mean', 'distance_median']: + if key in rendering: + utils.save_img_f32(rendering[key], + path_fn(f'{key}_{idx:03d}.tiff')) + + for key in ['normals']: + if key in rendering: + utils.save_img_u8(rendering[key] / 2. 
+ 0.5, + path_fn(f'{key}_{idx:03d}.png')) + + utils.save_img_f32(rendering['acc'], path_fn(f'acc_{idx:03d}.tiff')) + + if (not config.eval_only_once) and accelerator.is_main_process: + summary_writer.add_scalar('eval_median_render_time', np.median(render_times), + step) + for name in metrics[0]: + scores = [m[name] for m in metrics] + summary_writer.add_scalar('eval_metrics/' + name, np.mean(scores), step) + summary_writer.add_histogram('eval_metrics/' + 'perimage_' + name, scores, + step) + for name in metrics_cc[0]: + scores = [m[name] for m in metrics_cc] + summary_writer.add_scalar('eval_metrics_cc/' + name, np.mean(scores), step) + summary_writer.add_histogram('eval_metrics_cc/' + 'perimage_' + name, + scores, step) + + for i, r, b in showcases: + if config.vis_decimate > 1: + d = config.vis_decimate + decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d] + else: + decimate_fn = lambda x: x + r = tree_map(decimate_fn, r) + b = tree_map(decimate_fn, b) + visualizations = vis.visualize_suite(r, b) + for k, v in visualizations.items(): + if k == 'color': + v = postprocess_fn(v) + summary_writer.add_image(f'output_{k}_{i}', tb_process_fn(v), step) + if not config.render_path: + target = postprocess_fn(b['rgb']) + summary_writer.add_image(f'true_color_{i}', tb_process_fn(target), step) + pred = postprocess_fn(visualizations['color']) + residual = np.clip(pred - target + 0.5, 0, 1) + summary_writer.add_image(f'true_residual_{i}', tb_process_fn(residual), step) + if config.compute_normal_metrics: + summary_writer.add_image(f'true_normals_{i}', tb_process_fn(b['normals']) / 2. + 0.5, + step) + + if (config.eval_save_output and (not config.render_path) and + accelerator.is_main_process): + with utils.open_file(path_fn(f'render_times_{step}.txt'), 'w') as f: + f.write(' '.join([str(r) for r in render_times])) + logger.info(f'metrics:') + results = {} + num_buckets = config.multiscale_levels if config.multiscale else 1 + for name in metrics[0]: + with utils.open_file(path_fn(f'metric_{name}_{step}.txt'), 'w') as f: + ms = [m[name] for m in metrics] + f.write(' '.join([str(m) for m in ms])) + results[name] = ' | '.join( + list(map(str, np.mean(np.array(ms).reshape([-1, num_buckets]), 0).tolist()))) + with utils.open_file(path_fn(f'metric_avg_{step}.txt'), 'w') as f: + for name in metrics[0]: + f.write(f'{name}: {results[name]}\n') + logger.info(f'{name}: {results[name]}') + logger.info(f'metrics_cc:') + results_cc = {} + for name in metrics_cc[0]: + with utils.open_file(path_fn(f'metric_cc_{name}_{step}.txt'), 'w') as f: + ms = [m[name] for m in metrics_cc] + f.write(' '.join([str(m) for m in ms])) + results_cc[name] = ' | '.join( + list(map(str, np.mean(np.array(ms).reshape([-1, num_buckets]), 0).tolist()))) + with utils.open_file(path_fn(f'metric_cc_avg_{step}.txt'), 'w') as f: + for name in metrics[0]: + f.write(f'{name}: {results_cc[name]}\n') + logger.info(f'{name}: {results_cc[name]}') + if config.eval_save_ray_data: + for i, r, b in showcases: + rays = {k: v for k, v in r.items() if 'ray_' in k} + np.set_printoptions(threshold=sys.maxsize) + with utils.open_file(path_fn(f'ray_data_{step}_{i}.txt'), 'w') as f: + f.write(repr(rays)) + + if config.eval_only_once: + break + if config.early_exit_steps is not None: + num_steps = config.early_exit_steps + else: + num_steps = config.max_steps + if int(step) >= num_steps: + break + last_step = step + logger.info('Finish evaluation.') + + +if __name__ == '__main__': + with gin.config_scope('eval'): + app.run(main) diff --git a/extract.py 
b/extract.py new file mode 100644 index 0000000000000000000000000000000000000000..187db41270f893fbe3094fc496b9eb8e95f4caf7 --- /dev/null +++ b/extract.py @@ -0,0 +1,638 @@ +import logging +import os +import sys + +import cv2 +import numpy as np +from absl import app +import gin +from internal import configs +from internal import datasets +from internal import models +from internal import utils +from internal import coord +from internal import checkpoints +import torch +import accelerate +from tqdm import tqdm +from torch.utils._pytree import tree_map +import torch.nn.functional as F +from skimage import measure +import trimesh +import pymeshlab as pml + +configs.define_common_flags() + + +@torch.no_grad() +def evaluate_density(model, accelerator: accelerate.Accelerator, + points, config: configs.Config, std_value=0.0): + """ + Evaluate a signed distance function (SDF) for a batch of points. + + Args: + sdf: A callable function that takes a tensor of size (N, 3) containing + 3D points and returns a tensor of size (N,) with the SDF values. + points: A torch tensor containing 3D points. + + Returns: + A torch tensor with the SDF values evaluated at the given points. + """ + z = [] + for _, pnts in enumerate(tqdm(torch.split(points, config.render_chunk_size, dim=0), + desc="Evaluating density", leave=False, + disable=not accelerator.is_main_process)): + rays_remaining = pnts.shape[0] % accelerator.num_processes + if rays_remaining != 0: + padding = accelerator.num_processes - rays_remaining + pnts = torch.cat([pnts, torch.zeros_like(pnts[-padding:])], dim=0) + else: + padding = 0 + rays_per_host = pnts.shape[0] // accelerator.num_processes + start, stop = accelerator.process_index * rays_per_host, \ + (accelerator.process_index + 1) * rays_per_host + chunk_means = pnts[start:stop] + chunk_stds = torch.full_like(chunk_means[..., 0], std_value) + raw_density = model.nerf_mlp.predict_density(chunk_means[:, None], chunk_stds[:, None], no_warp=True)[0] + density = F.softplus(raw_density + model.nerf_mlp.density_bias) + density = accelerator.gather(density) + if padding > 0: + density = density[: -padding] + z.append(density) + z = torch.cat(z, dim=0) + return z + + +@torch.no_grad() +def evaluate_color(model, accelerator: accelerate.Accelerator, + points, config: configs.Config, std_value=0.0): + """ + Evaluate a signed distance function (SDF) for a batch of points. + + Args: + sdf: A callable function that takes a tensor of size (N, 3) containing + 3D points and returns a tensor of size (N,) with the SDF values. + points: A torch tensor containing 3D points. + + Returns: + A torch tensor with the SDF values evaluated at the given points. 
+ """ + z = [] + for _, pnts in enumerate(tqdm(torch.split(points, config.render_chunk_size, dim=0), + desc="Evaluating color", + disable=not accelerator.is_main_process)): + rays_remaining = pnts.shape[0] % accelerator.num_processes + if rays_remaining != 0: + padding = accelerator.num_processes - rays_remaining + pnts = torch.cat([pnts, torch.zeros_like(pnts[-padding:])], dim=0) + else: + padding = 0 + rays_per_host = pnts.shape[0] // accelerator.num_processes + start, stop = accelerator.process_index * rays_per_host, \ + (accelerator.process_index + 1) * rays_per_host + chunk_means = pnts[start:stop] + chunk_stds = torch.full_like(chunk_means[..., 0], std_value) + chunk_viewdirs = torch.zeros_like(chunk_means) + ray_results = model.nerf_mlp(False, chunk_means[:, None, None], chunk_stds[:, None, None], + chunk_viewdirs) + rgb = ray_results['rgb'][:, 0] + rgb = accelerator.gather(rgb) + if padding > 0: + rgb = rgb[: -padding] + z.append(rgb) + z = torch.cat(z, dim=0) + return z + + +@torch.no_grad() +def evaluate_color_projection(model, accelerator: accelerate.Accelerator, vertices, faces, config: configs.Config): + normals = auto_normals(vertices, faces.long()) + viewdirs = -normals + origins = vertices - 0.005 * viewdirs + vc = [] + chunk = config.render_chunk_size + model.num_levels = 1 + model.opaque_background = True + for i in tqdm(range(0, origins.shape[0], chunk), + desc="Evaluating color projection", + disable=not accelerator.is_main_process): + cur_chunk = min(chunk, origins.shape[0] - i) + rays_remaining = cur_chunk % accelerator.num_processes + rays_per_host = cur_chunk // accelerator.num_processes + if rays_remaining != 0: + padding = accelerator.num_processes - rays_remaining + rays_per_host += 1 + else: + padding = 0 + start = i + accelerator.process_index * rays_per_host + stop = start + rays_per_host + + batch = { + 'origins': origins[start:stop], + 'directions': viewdirs[start:stop], + 'viewdirs': viewdirs[start:stop], + 'cam_dirs': viewdirs[start:stop], + 'radii': torch.full_like(origins[start:stop, ..., :1], 0.000723), + 'near': torch.full_like(origins[start:stop, ..., :1], 0), + 'far': torch.full_like(origins[start:stop, ..., :1], 0.01), + } + batch = accelerator.pad_across_processes(batch) + with accelerator.autocast(): + renderings, ray_history = model( + False, + batch, + compute_extras=False, + train_frac=1) + rgb = renderings[-1]['rgb'] + acc = renderings[-1]['acc'] + + rgb /= acc.clamp_min(1e-5)[..., None] + rgb = rgb.clamp(0, 1) + + rgb = accelerator.gather(rgb) + rgb[torch.isnan(rgb) | torch.isinf(rgb)] = 1 + if padding > 0: + rgb = rgb[: -padding] + vc.append(rgb) + vc = torch.cat(vc, dim=0) + return vc + + +def auto_normals(verts, faces): + i0 = faces[:, 0] + i1 = faces[:, 1] + i2 = faces[:, 2] + + v0 = verts[i0, :] + v1 = verts[i1, :] + v2 = verts[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(verts) + v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where((v_nrm ** 2).sum(dim=-1, keepdims=True) > 1e-20, v_nrm, + torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=verts.device)) + v_nrm = F.normalize(v_nrm, dim=-1) + return v_nrm + + +def clean_mesh(verts, faces, v_pct=1, min_f=8, min_d=5, repair=True, remesh=True, remesh_size=0.01, logger=None, 
main_process=True): + # verts: [N, 3] + # faces: [N, 3] + tbar = tqdm(total=9, desc='Clean mesh', leave=False, disable=not main_process) + _ori_vert_shape = verts.shape + _ori_face_shape = faces.shape + + m = pml.Mesh(verts, faces) + ms = pml.MeshSet() + ms.add_mesh(m, 'mesh') # will copy! + + # filters + tbar.set_description('Remove unreferenced vertices') + ms.meshing_remove_unreferenced_vertices() # verts not refed by any faces + tbar.update() + + if v_pct > 0: + tbar.set_description('Remove unreferenced vertices') + ms.meshing_merge_close_vertices(threshold=pml.Percentage(v_pct)) # 1/10000 of bounding box diagonal + tbar.update() + + tbar.set_description('Remove duplicate faces') + ms.meshing_remove_duplicate_faces() # faces defined by the same verts + tbar.update() + + tbar.set_description('Remove null faces') + ms.meshing_remove_null_faces() # faces with area == 0 + tbar.update() + + if min_d > 0: + tbar.set_description('Remove connected component by diameter') + ms.meshing_remove_connected_component_by_diameter(mincomponentdiag=pml.Percentage(min_d)) + tbar.update() + + if min_f > 0: + tbar.set_description('Remove connected component by face number') + ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f) + tbar.update() + + if repair: + # tbar.set_description('Remove t vertices') + # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True) + tbar.set_description('Repair non manifold edges') + ms.meshing_repair_non_manifold_edges(method=0) + tbar.update() + tbar.set_description('Repair non manifold vertices') + ms.meshing_repair_non_manifold_vertices(vertdispratio=0) + tbar.update() + else: + tbar.update(2) + if remesh: + # tbar.set_description('Coord taubin smoothing') + # ms.apply_coord_taubin_smoothing() + tbar.set_description('Isotropic explicit remeshing') + ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.AbsoluteValue(remesh_size)) + tbar.update() + + # extract mesh + m = ms.current_mesh() + verts = m.vertex_matrix() + faces = m.face_matrix() + + if logger is not None: + logger.info(f'Mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') + + return verts, faces + + +def decimate_mesh(verts, faces, target, backend='pymeshlab', remesh=False, optimalplacement=True, logger=None): + # optimalplacement: default is True, but for flat mesh must turn False to prevent spike artifect. + + _ori_vert_shape = verts.shape + _ori_face_shape = faces.shape + + if backend == 'pyfqmr': + import pyfqmr + solver = pyfqmr.Simplify() + solver.setMesh(verts, faces) + solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False) + verts, faces, normals = solver.getMesh() + else: + + m = pml.Mesh(verts, faces) + ms = pml.MeshSet() + ms.add_mesh(m, 'mesh') # will copy! 
+ + # filters + # ms.meshing_decimation_clustering(threshold=pml.Percentage(1)) + ms.meshing_decimation_quadric_edge_collapse(targetfacenum=int(target), optimalplacement=optimalplacement) + + if remesh: + # ms.apply_coord_taubin_smoothing() + ms.meshing_isotropic_explicit_remeshing(iterations=3, targetlen=pml.Percentage(1)) + + # extract mesh + m = ms.current_mesh() + verts = m.vertex_matrix() + faces = m.face_matrix() + + if logger is not None: + logger.info(f'Mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}') + + return verts, faces + + +def main(unused_argv): + config = configs.load_config() + config.compute_visibility = True + + config.exp_path = os.path.join("exp", config.exp_name) + config.mesh_path = os.path.join("exp", config.exp_name, "mesh") + config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints') + os.makedirs(config.mesh_path, exist_ok=True) + + # accelerator for DDP + accelerator = accelerate.Accelerator() + device = accelerator.device + + # setup logger + logging.basicConfig( + format="%(asctime)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + handlers=[logging.StreamHandler(sys.stdout), + logging.FileHandler(os.path.join(config.exp_path, 'log_extract.txt'))], + level=logging.INFO, + ) + sys.excepthook = utils.handle_exception + logger = accelerate.logging.get_logger(__name__) + logger.info(config) + logger.info(accelerator.state, main_process_only=False) + + config.world_size = accelerator.num_processes + config.global_rank = accelerator.process_index + accelerate.utils.set_seed(config.seed, device_specific=True) + + # setup model and optimizer + model = models.Model(config=config) + model = accelerator.prepare(model) + step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger) + model.eval() + module = accelerator.unwrap_model(model) + + visibility_path = os.path.join(config.mesh_path, 'visibility_mask_{:.1f}.pt'.format(config.mesh_radius)) + visibility_resolution = config.visibility_resolution + if not os.path.exists(visibility_path): + logger.info('Generate visibility mask...') + # load dataset + dataset = datasets.load_dataset('train', config.data_dir, config) + dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)), + num_workers=4, + shuffle=True, + batch_size=1, + collate_fn=dataset.collate_fn, + persistent_workers=True, + ) + + visibility_mask = torch.ones( + (1, 1, visibility_resolution, visibility_resolution, visibility_resolution), requires_grad=True + ).to(device) + visibility_mask.retain_grad() + tbar = tqdm(dataloader, desc='Generating visibility grid', disable=not accelerator.is_main_process) + for index, batch in enumerate(tbar): + batch = accelerate.utils.send_to_device(batch, accelerator.device) + + rendering = models.render_image(model, accelerator, + batch, False, 1, config, + verbose=False, return_weights=True) + + coords = rendering['coord'].reshape(-1, 3) + weights = rendering['weights'].reshape(-1) + + valid_points = coords[weights > config.valid_weight_thresh] + valid_points /= config.mesh_radius + # update mask based on ray samples + with torch.enable_grad(): + out = torch.nn.functional.grid_sample(visibility_mask, + valid_points[None, None, None], + align_corners=True) + out.sum().backward() + tbar.set_postfix({"visibility_mask": (visibility_mask.grad > 0.0001).float().mean().item()}) + # if index == 10: + # break + visibility_mask = (visibility_mask.grad > 0.0001).float() + if accelerator.is_main_process: + 
torch.save(visibility_mask.detach().cpu(), visibility_path) + else: + logger.info('Load visibility mask from {}'.format(visibility_path)) + visibility_mask = torch.load(visibility_path, map_location=device) + + space = config.mesh_radius * 2 / (config.visibility_resolution - 1) + + logger.info("Extract mesh from visibility mask...") + visibility_mask_np = visibility_mask[0, 0].permute(2, 1, 0).detach().cpu().numpy() + verts, faces, normals, values = measure.marching_cubes( + volume=-visibility_mask_np, + level=-0.5, + spacing=(space, space, space)) + verts -= config.mesh_radius + if config.extract_visibility: + meshexport = trimesh.Trimesh(verts, faces) + meshexport.export(os.path.join(config.mesh_path, "visibility_mask_{}.ply".format(config.mesh_radius)), "ply") + logger.info("Extract visibility mask done.") + + # Initialize variables + crop_n = 512 + grid_min = verts.min(axis=0) + grid_max = verts.max(axis=0) + space = ((grid_max - grid_min).prod() / config.mesh_voxels) ** (1 / 3) + world_size = ((grid_max - grid_min) / space).astype(np.int32) + Nx, Ny, Nz = np.maximum(1, world_size // crop_n) + crop_n_x, crop_n_y, crop_n_z = world_size // [Nx, Ny, Nz] + xs = np.linspace(grid_min[0], grid_max[0], Nx + 1) + ys = np.linspace(grid_min[1], grid_max[1], Ny + 1) + zs = np.linspace(grid_min[2], grid_max[2], Nz + 1) + # Initialize meshes list + meshes = [] + + # Iterate over the grid + for i in range(Nx): + for j in range(Ny): + for k in range(Nz): + logger.info(f"Process grid cell ({i + 1}/{Nx}, {j + 1}/{Ny}, {k + 1}/{Nz})...") + # Calculate grid cell boundaries + x_min, x_max = xs[i], xs[i + 1] + y_min, y_max = ys[j], ys[j + 1] + z_min, z_max = zs[k], zs[k + 1] + + # Create point grid + x = np.linspace(x_min, x_max, crop_n_x) + y = np.linspace(y_min, y_max, crop_n_y) + z = np.linspace(z_min, z_max, crop_n_z) + xx, yy, zz = np.meshgrid(x, y, z, indexing="ij") + points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, + dtype=torch.float, + device=device) + # Construct point pyramids + points_tmp = points.reshape(crop_n_x, crop_n_y, crop_n_z, 3)[None] + points_tmp /= config.mesh_radius + current_mask = torch.nn.functional.grid_sample(visibility_mask, points_tmp, align_corners=True) + current_mask = (current_mask > 0.0).cpu().numpy()[0, 0] + + pts_density = evaluate_density(module, accelerator, points, + config, std_value=config.std_value) + + # bound the vertices + points_world = coord.inv_contract(2 * points) + pts_density[points_world.norm(dim=-1) > config.mesh_max_radius] = 0.0 + + z = pts_density.detach().cpu().numpy() + + # Skip if no surface found + valid_z = z.reshape(crop_n_x, crop_n_y, crop_n_z)[current_mask] + if valid_z.shape[0] <= 0 or ( + np.min(valid_z) > config.isosurface_threshold or np.max( + valid_z) < config.isosurface_threshold + ): + continue + + if not (np.min(z) > config.isosurface_threshold or np.max(z) < config.isosurface_threshold): + # Extract mesh + logger.info('Extract mesh...') + z = z.astype(np.float32) + verts, faces, _, _ = measure.marching_cubes( + volume=-z.reshape(crop_n_x, crop_n_y, crop_n_z), + level=-config.isosurface_threshold, + spacing=( + (x_max - x_min) / (crop_n_x - 1), + (y_max - y_min) / (crop_n_y - 1), + (z_max - z_min) / (crop_n_z - 1), + ), + mask=current_mask, + ) + verts = verts + np.array([x_min, y_min, z_min]) + + meshcrop = trimesh.Trimesh(verts, faces) + logger.info('Extract vertices: {}, faces: {}'.format(meshcrop.vertices.shape[0], + meshcrop.faces.shape[0])) + meshes.append(meshcrop) + # Save mesh + 
logger.info('Concatenate mesh...') + combined_mesh = trimesh.util.concatenate(meshes) + + # from https://github.com/ashawkey/stable-dreamfusion/blob/main/nerf/renderer.py + # clean + logger.info('Clean mesh...') + vertices = combined_mesh.vertices.astype(np.float32) + faces = combined_mesh.faces.astype(np.int32) + + vertices, faces = clean_mesh(vertices, faces, + remesh=False, remesh_size=0.01, + logger=logger, main_process=accelerator.is_main_process) + + v = torch.from_numpy(vertices).contiguous().float().to(device) + v = coord.inv_contract(2 * v) + vertices = v.detach().cpu().numpy() + f = torch.from_numpy(faces).contiguous().int().to(device) + + # decimation + if config.decimate_target > 0 and faces.shape[0] > config.decimate_target: + logger.info('Decimate mesh...') + vertices, triangles = decimate_mesh(vertices, faces, config.decimate_target, logger=logger) + # import ipdb; ipdb.set_trace() + if config.vertex_color: + # batched inference to avoid OOM + logger.info('Evaluate mesh vertex color...') + if config.vertex_projection: + rgbs = evaluate_color_projection(module, accelerator, v, f, config) + else: + rgbs = evaluate_color(module, accelerator, v, + config, std_value=config.std_value) + rgbs = (rgbs * 255).detach().cpu().numpy().astype(np.uint8) + if accelerator.is_main_process: + logger.info('Export mesh (vertex color)...') + mesh = trimesh.Trimesh(vertices, faces, + vertex_colors=rgbs, + process=False) # important, process=True leads to seg fault... + mesh.export(os.path.join(config.mesh_path, 'mesh_{}.ply'.format(config.mesh_radius))) + logger.info('Finish extracting mesh.') + return + + def _export(v, f, h0=2048, w0=2048, ssaa=1, name=''): + logger.info('Export mesh (atlas)...') + # v, f: torch Tensor + device = v.device + v_np = v.cpu().numpy() # [N, 3] + f_np = f.cpu().numpy() # [M, 3] + + # unwrap uvs + import xatlas + import nvdiffrast.torch as dr + from sklearn.neighbors import NearestNeighbors + from scipy.ndimage import binary_dilation, binary_erosion + + logger.info(f'Running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}') + atlas = xatlas.Atlas() + atlas.add_mesh(v_np, f_np) + chart_options = xatlas.ChartOptions() + chart_options.max_iterations = 4 # for faster unwrap... 
+ atlas.generate(chart_options=chart_options) + vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] + + # vmapping, ft_np, vt_np = xatlas.parametrize(v_np, f_np) # [N], [M, 3], [N, 2] + + vt = torch.from_numpy(vt_np.astype(np.float32)).float().to(device) + ft = torch.from_numpy(ft_np.astype(np.int64)).int().to(device) + + # render uv maps + uv = vt * 2.0 - 1.0 # uvs to range [-1, 1] + uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1) # [N, 4] + + if ssaa > 1: + h = int(h0 * ssaa) + w = int(w0 * ssaa) + else: + h, w = h0, w0 + + if h <= 2048 and w <= 2048: + glctx = dr.RasterizeCudaContext() + else: + glctx = dr.RasterizeGLContext() + + rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w)) # [1, h, w, 4] + xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, h, w, 3] + mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f) # [1, h, w, 1] + + # masked query + xyzs = xyzs.view(-1, 3) + mask = (mask > 0).view(-1) + + feats = torch.zeros(h * w, 3, device=device, dtype=torch.float32) + + if mask.any(): + xyzs = xyzs[mask] # [M, 3] + + # batched inference to avoid OOM + all_feats = evaluate_color(module, accelerator, xyzs, + config, std_value=config.std_value) + feats[mask] = all_feats + + feats = feats.view(h, w, -1) + mask = mask.view(h, w) + + # quantize [0.0, 1.0] to [0, 255] + feats = feats.cpu().numpy() + feats = (feats * 255).astype(np.uint8) + + ### NN search as an antialiasing ... + mask = mask.cpu().numpy() + + inpaint_region = binary_dilation(mask, iterations=3) + inpaint_region[mask] = 0 + + search_region = mask.copy() + not_search_region = binary_erosion(search_region, iterations=2) + search_region[not_search_region] = 0 + + search_coords = np.stack(np.nonzero(search_region), axis=-1) + inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1) + + knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords) + _, indices = knn.kneighbors(inpaint_coords) + + feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)] + + feats = cv2.cvtColor(feats, cv2.COLOR_RGB2BGR) + + # do ssaa after the NN search, in numpy + if ssaa > 1: + feats = cv2.resize(feats, (w0, h0), interpolation=cv2.INTER_LINEAR) + + cv2.imwrite(os.path.join(config.mesh_path, f'{name}albedo.png'), feats) + + # save obj (v, vt, f /) + obj_file = os.path.join(config.mesh_path, f'{name}mesh.obj') + mtl_file = os.path.join(config.mesh_path, f'{name}mesh.mtl') + + logger.info(f'writing obj mesh to {obj_file}') + with open(obj_file, "w") as fp: + fp.write(f'mtllib {name}mesh.mtl \n') + + logger.info(f'writing vertices {v_np.shape}') + for v in v_np: + fp.write(f'v {v[0]} {v[1]} {v[2]} \n') + + logger.info(f'writing vertices texture coords {vt_np.shape}') + for v in vt_np: + fp.write(f'vt {v[0]} {1 - v[1]} \n') + + logger.info(f'writing faces {f_np.shape}') + fp.write(f'usemtl mat0 \n') + for i in range(len(f_np)): + fp.write( + f"f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1} {f_np[i, 1] + 1}/{ft_np[i, 1] + 1} {f_np[i, 2] + 1}/{ft_np[i, 2] + 1} \n") + + with open(mtl_file, "w") as fp: + fp.write(f'newmtl mat0 \n') + fp.write(f'Ka 1.000000 1.000000 1.000000 \n') + fp.write(f'Kd 1.000000 1.000000 1.000000 \n') + fp.write(f'Ks 0.000000 0.000000 0.000000 \n') + fp.write(f'Tr 1.000000 \n') + fp.write(f'illum 1 \n') + fp.write(f'Ns 0.000000 \n') + fp.write(f'map_Kd {name}albedo.png \n') + + # could be extremely slow + _export(v, f) + + logger.info('Finish extracting mesh.') + + +if __name__ == '__main__': + with 
gin.config_scope('bake'): + app.run(main) diff --git a/gridencoder/__init__.py b/gridencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f1476cef5314e0918b963d1ac64ee0613a7743d5 --- /dev/null +++ b/gridencoder/__init__.py @@ -0,0 +1 @@ +from .grid import GridEncoder \ No newline at end of file diff --git a/gridencoder/backend.py b/gridencoder/backend.py new file mode 100644 index 0000000000000000000000000000000000000000..d99acb1f4353786e16468948780f377008d94872 --- /dev/null +++ b/gridencoder/backend.py @@ -0,0 +1,40 @@ +import os +from torch.utils.cpp_extension import load + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. + if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +_backend = load(name='_grid_encoder', + extra_cflags=c_flags, + extra_cuda_cflags=nvcc_flags, + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + ) + +__all__ = ['_backend'] \ No newline at end of file diff --git a/gridencoder/grid.py b/gridencoder/grid.py new file mode 100644 index 0000000000000000000000000000000000000000..296e3a8e101af4b9485008b31770d5dd0d3799e9 --- /dev/null +++ b/gridencoder/grid.py @@ -0,0 +1,198 @@ +import numpy as np + +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import _gridencoder as _backend +except ImportError: + from .backend import _backend + +_gridtype_to_id = { + 'hash': 0, + 'tiled': 1, +} + +_interp_to_id = { + 'linear': 0, + 'smoothstep': 1, +} + +class _grid_encode(Function): + @staticmethod + @custom_fwd + def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False, gridtype=0, align_corners=False, interpolation=0): + # inputs: [B, D], float in [0, 1] + # embeddings: [sO, C], float + # offsets: [L + 1], int + # RETURN: [B, F], float + + inputs = inputs.contiguous() + + B, D = inputs.shape # batch size, coord dim + L = offsets.shape[0] - 1 # level + C = embeddings.shape[1] # embedding dim for each level + S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = base_resolution # base resolution + + # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) + # if C % 2 != 0, force float, since half for atomicAdd is very slow. 
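+ # (i.e. when C is odd the embeddings stay in full precision: the backward kernel can only take
+ # the fast __half2 atomicAdd path when the channel count is even, and would otherwise fall back
+ # to slow per-element half atomicAdd.)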
+ if torch.is_autocast_enabled() and C % 2 == 0: + embeddings = embeddings.to(torch.half) + + # L first, optimize cache for cuda kernel, but needs an extra permute later + outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) + + if calc_grad_inputs: + dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=embeddings.dtype) + else: + dy_dx = None + + _backend.grid_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, dy_dx, gridtype, align_corners, interpolation) + + # permute back to [B, L * C] + outputs = outputs.permute(1, 0, 2).reshape(B, L * C) + + ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) + ctx.dims = [B, D, C, L, S, H, gridtype, interpolation] + ctx.align_corners = align_corners + + return outputs + + @staticmethod + #@once_differentiable + @custom_bwd + def backward(ctx, grad): + + inputs, embeddings, offsets, dy_dx = ctx.saved_tensors + B, D, C, L, S, H, gridtype, interpolation = ctx.dims + align_corners = ctx.align_corners + + # grad: [B, L * C] --> [L, B, C] + grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() + + grad_embeddings = torch.zeros_like(embeddings) + + if dy_dx is not None: + grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) + else: + grad_inputs = None + + _backend.grid_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interpolation) + + if dy_dx is not None: + grad_inputs = grad_inputs.to(inputs.dtype) + + return grad_inputs, grad_embeddings, None, None, None, None, None, None, None + + + +grid_encode = _grid_encode.apply + + +class GridEncoder(nn.Module): + def __init__(self, input_dim=3, num_levels=16, level_dim=2, + per_level_scale=2, base_resolution=16, + log2_hashmap_size=19, desired_resolution=None, + gridtype='hash', align_corners=False, + interpolation='linear', init_std=1e-4): + super().__init__() + + # the finest resolution desired at the last level, if provided, overridee per_level_scale + if desired_resolution is not None: + per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) + + self.input_dim = input_dim # coord dims, 2 or 3 + self.num_levels = num_levels # num levels, each level multiply resolution by 2 + self.level_dim = level_dim # encode channels per level + self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. 
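+ # Worked example (illustrative numbers): with base_resolution=16, desired_resolution=4096 and
+ # num_levels=16, per_level_scale = exp2(log2(4096 / 16) / 15) ≈ 1.447, so the last level
+ # reaches the desired 4096 resolution.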
+ self.log2_hashmap_size = log2_hashmap_size + self.base_resolution = base_resolution + self.output_dim = num_levels * level_dim + self.gridtype = gridtype + self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" + self.interpolation = interpolation + self.interp_id = _interp_to_id[interpolation] # "linear" or "smoothstep" + self.align_corners = align_corners + self.init_std = init_std + + # allocate parameters + resolutions = [] + offsets = [] + offset = 0 + self.max_params = 2 ** log2_hashmap_size + for i in range(num_levels): + resolution = int(np.ceil(base_resolution * per_level_scale ** i)) + resolution = (resolution if align_corners else resolution + 1) + params_in_level = min(self.max_params, resolution ** input_dim) # limit max number + params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible + resolutions.append(resolution) + offsets.append(offset) + offset += params_in_level + offsets.append(offset) + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) + self.register_buffer('offsets', offsets) + idx = torch.empty(offset, dtype=torch.long) + for i in range(self.num_levels): + idx[offsets[i]:offsets[i+1]] = i + self.register_buffer('idx', idx) + self.register_buffer('grid_sizes', torch.from_numpy(np.array(resolutions, dtype=np.int32))) + + self.n_params = offsets[-1] * level_dim + + # parameters + self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) + + self.reset_parameters() + + def reset_parameters(self): + std = self.init_std + self.embeddings.data.uniform_(-std, std) + + def __repr__(self): + return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners} interpolation={self.interpolation}" + + def forward(self, inputs, bound=1): + # inputs: [..., input_dim], normalized real world positions in [-bound, bound] + # return: [..., num_levels * level_dim] + + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + # inputs = inputs.clamp(0, 1) + #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) + + prefix_shape = list(inputs.shape[:-1]) + inputs = inputs.view(-1, self.input_dim) + + outputs = grid_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad, self.gridtype_id, self.align_corners, self.interp_id) + outputs = outputs.view(prefix_shape + [self.output_dim]) + + #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) + + return outputs + + # always run in float precision! + @torch.cuda.amp.autocast(enabled=False) + def grad_total_variation(self, weight=1e-7, inputs=None, bound=1, B=1000000): + # inputs: [..., input_dim], float in [-b, b], location to calculate TV loss. 
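+ # Usage sketch (hypothetical variable names): call this between loss.backward() and
+ # optimizer.step(), e.g.
+ #   loss.backward()
+ #   encoder.grad_total_variation(weight=1e-7)
+ #   optimizer.step()
+ # since it adds the TV gradient directly into self.embeddings.grad.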
+ + D = self.input_dim + C = self.embeddings.shape[1] # embedding dim for each level + L = self.offsets.shape[0] - 1 # level + S = np.log2(self.per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f + H = self.base_resolution # base resolution + + if inputs is None: + # randomized in [0, 1] + inputs = torch.rand(B, self.input_dim, device=self.embeddings.device) + else: + inputs = (inputs + bound) / (2 * bound) # map to [0, 1] + inputs = inputs.view(-1, self.input_dim) + B = inputs.shape[0] + + if self.embeddings.grad is None: + raise ValueError('grad is None, should be called after loss.backward() and before optimizer.step()!') + + _backend.grad_total_variation(inputs, self.embeddings, self.embeddings.grad, self.offsets, weight, B, D, C, L, S, H, self.gridtype_id, self.align_corners) \ No newline at end of file diff --git a/gridencoder/setup.py b/gridencoder/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..714bf1cad7880fe25dca319414748c15e86cc48e --- /dev/null +++ b/gridencoder/setup.py @@ -0,0 +1,50 @@ +import os +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +_src_path = os.path.dirname(os.path.abspath(__file__)) + +nvcc_flags = [ + '-O3', '-std=c++14', + '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', +] + +if os.name == "posix": + c_flags = ['-O3', '-std=c++14'] +elif os.name == "nt": + c_flags = ['/O2', '/std:c++17'] + + # find cl.exe + def find_cl_path(): + import glob + for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: + paths = sorted(glob.glob(r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % edition), reverse=True) + if paths: + return paths[0] + + # If cl.exe is not on path, try to find it. 
+ if os.system("where cl.exe >nul 2>nul") != 0: + cl_path = find_cl_path() + if cl_path is None: + raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation") + os.environ["PATH"] += ";" + cl_path + +setup( + name='gridencoder', # package name, import this to use python API + ext_modules=[ + CUDAExtension( + name='_gridencoder', # extension name, import this to use CUDA API + sources=[os.path.join(_src_path, 'src', f) for f in [ + 'gridencoder.cu', + 'bindings.cpp', + ]], + extra_compile_args={ + 'cxx': c_flags, + 'nvcc': nvcc_flags, + } + ), + ], + cmdclass={ + 'build_ext': BuildExtension, + } +) \ No newline at end of file diff --git a/gridencoder/src/bindings.cpp b/gridencoder/src/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..93dea943c939cffc7ec73c76410aeff7afddc1f9 --- /dev/null +++ b/gridencoder/src/bindings.cpp @@ -0,0 +1,9 @@ +#include <torch/extension.h> + +#include "gridencoder.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); + m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); + m.def("grad_total_variation", &grad_total_variation, "grad_total_variation (CUDA)"); +} \ No newline at end of file diff --git a/gridencoder/src/gridencoder.cu b/gridencoder/src/gridencoder.cu new file mode 100644 index 0000000000000000000000000000000000000000..cba5e94f5f4ca6b728bc9006c79e80cb0fce62dd --- /dev/null +++ b/gridencoder/src/gridencoder.cu @@ -0,0 +1,645 @@ +#include <stdint.h> +#include <cuda.h> +#include <cuda_fp16.h> + +#include <cuda_runtime.h> +#include <ATen/cuda/CUDAContext.h> + +#include <torch/torch.h> +#include <algorithm> + +#include <stdexcept> +#include <cstdio> + + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") +#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") +#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") + + +// just for compatibility of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... program will never reach here! + __device__ inline at::Half atomicAdd(at::Half *address, at::Half val) { + // requires CUDA >= 10 and ARCH >= 70 + // this is very slow compared to float or __half2, never use it.
+ //return atomicAdd(reinterpret_cast<__half*>(address), val); +} + + +template +__host__ __device__ inline T div_round_up(T val, T divisor) { + return (val + divisor - 1) / divisor; +} + +template +__host__ __device__ inline T clamp(const T v, const T2 lo, const T2 hi) { + return min(max(v, lo), hi); +} + +template +__device__ inline T smoothstep(T val) { + return val*val*(3.0f - 2.0f * val); +} + +template +__device__ inline T smoothstep_derivative(T val) { + return 6*val*(1.0f - val); +} + + +template +__device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { + + // coherent type of hashing + constexpr uint32_t primes[7] = { 1u, 2654435761u, 805459861u, 3674653429u, 2097192037u, 1434869437u, 2165219737u }; + + uint32_t result = 0; + #pragma unroll + for (uint32_t i = 0; i < D; ++i) { + result ^= pos_grid[i] * primes[i]; + } + + return result; +} + + +template +__device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { + uint32_t stride = 1; + uint32_t index = 0; + + #pragma unroll + for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { + index += pos_grid[d] * stride; + stride *= align_corners ? resolution: (resolution + 1); + } + + // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. + // gridtype: 0 == hash, 1 == tiled + if (gridtype == 0 && stride > hashmap_size) { + index = fast_hash(pos_grid); + } + + return (index % hashmap_size) * C + ch; +} + + +template +__global__ void kernel_grid( + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ outputs, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + scalar_t * __restrict__ dy_dx, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + grid += (uint32_t)offsets[level] * C; + inputs += b * D; + outputs += level * B * C + b * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + // if input out of bound, just set output to 0 + if (flag_oob) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = 0; + } + if (dy_dx) { + dy_dx += b * D * L * C + level * D * C; // B L D C + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[d * C + ch] = 0; + } + } + } + return; + } + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate (always use float for precision!) + float pos[D]; + float pos_deriv[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 
0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos_deriv[d] = smoothstep_derivative(pos[d]); + pos[d] = smoothstep(pos[d]); + } else { + pos_deriv[d] = 1.0f; // linear deriv is default to 1 + } + + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // interpolate + scalar_t results[C] = {0}; // temp results in register + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + // writing to register (fast) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results[ch] += w * grid[index + ch]; + } + + //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + outputs[ch] = results[ch]; + } + + // prepare dy_dx + // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9 + if (dy_dx) { + + dy_dx += b * D * L * C + level * D * C; // B L D C + + #pragma unroll + for (uint32_t gd = 0; gd < D; gd++) { + + scalar_t results_grad[C] = {0}; + + #pragma unroll + for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { + float w = scale; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t nd = 0; nd < D - 1; nd++) { + const uint32_t d = (nd >= gd) ? (nd + 1) : nd; + + if ((idx & (1 << nd)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + pos_grid_local[gd] = pos_grid[gd]; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + pos_grid_local[gd] = pos_grid[gd] + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]) * pos_deriv[gd]; + } + } + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + dy_dx[gd * C + ch] = results_grad[ch]; + } + } + } +} + + +template +__global__ void kernel_grid_backward( + const scalar_t * __restrict__ grad, + const float * __restrict__ inputs, + const scalar_t * __restrict__ grid, + const int * __restrict__ offsets, + scalar_t * __restrict__ grad_grid, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners, + const uint32_t interp +) { + const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; + if (b >= B) return; + + const uint32_t level = blockIdx.y; + const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; + + // locate + grad_grid += offsets[level] * C; + inputs += b * D; + grad += level * B * C + b * C + ch; // L, B, C + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // check input range (should be in [0, 1]) + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || 
inputs[d] > 1) { + return; // grad is init as 0, so we simply return. + } + } + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + pos[d] -= (float)pos_grid[d]; + // smoothstep instead of linear + if (interp == 1) { + pos[d] = smoothstep(pos[d]); + } + } + + scalar_t grad_cur[N_C] = {0}; // fetch to register + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + grad_cur[c] = grad[c]; + } + + // interpolate + #pragma unroll + for (uint32_t idx = 0; idx < (1 << D); idx++) { + float w = 1; + uint32_t pos_grid_local[D]; + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if ((idx & (1 << d)) == 0) { + w *= 1 - pos[d]; + pos_grid_local[d] = pos_grid[d]; + } else { + w *= pos[d]; + pos_grid_local[d] = pos_grid[d] + 1; + } + } + + uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); + + // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 + // TODO: use float which is better than __half, if N_C % 2 != 0 + if (std::is_same::value && N_C % 2 == 0) { + #pragma unroll + for (uint32_t c = 0; c < N_C; c += 2) { + // process two __half at once (by interpreting as a __half2) + __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; + atomicAdd((__half2*)&grad_grid[index + c], v); + } + // float, or __half when N_C % 2 != 0 (which means C == 1) + } else { + #pragma unroll + for (uint32_t c = 0; c < N_C; c++) { + atomicAdd(&grad_grid[index + c], w * grad_cur[c]); + } + } + } +} + + +template +__global__ void kernel_input_backward( + const scalar_t * __restrict__ grad, + const scalar_t * __restrict__ dy_dx, + scalar_t * __restrict__ grad_inputs, + uint32_t B, uint32_t L +) { + const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; + if (t >= B * D) return; + + const uint32_t b = t / D; + const uint32_t d = t - b * D; + + dy_dx += b * L * D * C; + + scalar_t result = 0; + + # pragma unroll + for (int l = 0; l < L; l++) { + # pragma unroll + for (int ch = 0; ch < C; ch++) { + result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; + } + } + + grad_inputs[t] = result; +} + + +template +void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [L, B, C], float (L first, so only one level of hashmap needs to fit into cache at a time.) 
+// H: base resolution +// dy_dx: [B, L * D * C] +template +void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + switch (D) { + case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 4: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + case 5: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, dy_dx, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + +template +void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + static constexpr uint32_t N_THREAD = 256; + const uint32_t N_C = std::min(2u, C); // n_features_per_thread + const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 }; + switch (C) { + case 1: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 2: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 4: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + case 8: + kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners, interp); + if (dy_dx) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); + break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +// grad: [L, B, C], float +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// grad_embeddings: [sO, C] +// H: base resolution +template +void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + switch (D) { + case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + case 4: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, 
grad_inputs, gridtype, align_corners, interp); break; + case 5: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, dy_dx, grad_inputs, gridtype, align_corners, interp); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + + +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(outputs); + // CHECK_CUDA(dy_dx); + + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(outputs); + // CHECK_CONTIGUOUS(dy_dx); + + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(outputs); + // CHECK_IS_FLOATING(dy_dx); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grid_encode_forward", ([&] { + grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, gridtype, align_corners, interp); + })); +} + +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp) { + CHECK_CUDA(grad); + CHECK_CUDA(inputs); + CHECK_CUDA(embeddings); + CHECK_CUDA(offsets); + CHECK_CUDA(grad_embeddings); + // CHECK_CUDA(dy_dx); + // CHECK_CUDA(grad_inputs); + + CHECK_CONTIGUOUS(grad); + CHECK_CONTIGUOUS(inputs); + CHECK_CONTIGUOUS(embeddings); + CHECK_CONTIGUOUS(offsets); + CHECK_CONTIGUOUS(grad_embeddings); + // CHECK_CONTIGUOUS(dy_dx); + // CHECK_CONTIGUOUS(grad_inputs); + + CHECK_IS_FLOATING(grad); + CHECK_IS_FLOATING(inputs); + CHECK_IS_FLOATING(embeddings); + CHECK_IS_INT(offsets); + CHECK_IS_FLOATING(grad_embeddings); + // CHECK_IS_FLOATING(dy_dx); + // CHECK_IS_FLOATING(grad_inputs); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "grid_encode_backward", ([&] { + grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H, dy_dx.has_value() ? dy_dx.value().data_ptr() : nullptr, grad_inputs.has_value() ? 
grad_inputs.value().data_ptr() : nullptr, gridtype, align_corners, interp); + })); + +} + + +template +__global__ void kernel_grad_tv( + const scalar_t * __restrict__ inputs, + const scalar_t * __restrict__ grid, + scalar_t * __restrict__ grad, + const int * __restrict__ offsets, + const float weight, + const uint32_t B, const uint32_t L, const float S, const uint32_t H, + const uint32_t gridtype, + const bool align_corners +) { + const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; + + if (b >= B) return; + + const uint32_t level = blockIdx.y; + + // locate + inputs += b * D; + grid += (uint32_t)offsets[level] * C; + grad += (uint32_t)offsets[level] * C; + + // check input range (should be in [0, 1]) + bool flag_oob = false; + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + if (inputs[d] < 0 || inputs[d] > 1) { + flag_oob = true; + } + } + + // if input out of bound, do nothing + if (flag_oob) return; + + const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; + const float scale = exp2f(level * S) * H - 1.0f; + const uint32_t resolution = (uint32_t)ceil(scale) + 1; + + // calculate coordinate + float pos[D]; + uint32_t pos_grid[D]; // [0, resolution] + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + pos[d] = inputs[d] * scale + (align_corners ? 0.0f : 0.5f); + pos_grid[d] = floorf(pos[d]); + // pos[d] -= (float)pos_grid[d]; // not used + } + + //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); + + // total variation on pos_grid + scalar_t results[C] = {0}; // temp results in register + scalar_t idelta[C] = {0}; + + uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + scalar_t w = weight / (2 * D); + + #pragma unroll + for (uint32_t d = 0; d < D; d++) { + + uint32_t cur_d = pos_grid[d]; + scalar_t grad_val; + + // right side + if (cur_d < resolution) { + pos_grid[d] = cur_d + 1; + uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_right + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_right + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // left side + if (cur_d > 0) { + pos_grid[d] = cur_d - 1; + uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid); + + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // results[ch] += w * clamp(grid[index + ch] - grid[index_left + ch], -1.0f, 1.0f); + grad_val = (grid[index + ch] - grid[index_left + ch]); + results[ch] += grad_val; + idelta[ch] += grad_val * grad_val; + } + } + + // reset + pos_grid[d] = cur_d; + } + + // writing to global memory (slow) + #pragma unroll + for (uint32_t ch = 0; ch < C; ch++) { + // index may collide, so use atomic! 
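+ // Note: results[ch] is the sum of signed neighbor differences and idelta[ch] the sum of their
+ // squares, so the value accumulated below is roughly weight / (2 * D) times the gradient of the
+ // L2 norm of the local differences, i.e. sum(diff) / sqrt(sum(diff^2) + 1e-9).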
+ atomicAdd(&grad[index + ch], w * results[ch] * rsqrtf(idelta[ch] + 1e-9f)); + } + +} + + +template +void kernel_grad_tv_wrapper(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + static constexpr uint32_t N_THREAD = 512; + const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; + switch (C) { + case 1: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 2: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 4: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + case 8: kernel_grad_tv<<>>(inputs, embeddings, grad, offsets, weight, B, L, S, H, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +template +void grad_total_variation_cuda(const scalar_t *inputs, const scalar_t *embeddings, scalar_t *grad, const int *offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + switch (D) { + case 2: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 3: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 4: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + case 5: kernel_grad_tv_wrapper(inputs, embeddings, grad, offsets, weight, B, C, L, S, H, gridtype, align_corners); break; + default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; + } +} + + +void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners) { + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + embeddings.scalar_type(), "grad_total_variation", ([&] { + grad_total_variation_cuda(inputs.data_ptr(), embeddings.data_ptr(), grad.data_ptr(), offsets.data_ptr(), weight, B, D, C, L, S, H, gridtype, align_corners); + })); +} \ No newline at end of file diff --git a/gridencoder/src/gridencoder.h b/gridencoder/src/gridencoder.h new file mode 100644 index 0000000000000000000000000000000000000000..1b385755d13711b04df4866dd654e88b48054554 --- /dev/null +++ b/gridencoder/src/gridencoder.h @@ -0,0 +1,17 @@ +#ifndef _HASH_ENCODE_H +#define _HASH_ENCODE_H + +#include +#include + +// inputs: [B, D], float, in [0, 1] +// embeddings: [sO, C], float +// offsets: [L + 1], uint32_t +// outputs: [B, L * C], float +// H: base resolution +void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, at::optional dy_dx, const uint32_t gridtype, const bool align_corners, const uint32_t interp); +void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, 
const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const at::optional dy_dx, at::optional grad_inputs, const uint32_t gridtype, const bool align_corners, const uint32_t interp); + +void grad_total_variation(const at::Tensor inputs, const at::Tensor embeddings, at::Tensor grad, const at::Tensor offsets, const float weight, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const uint32_t gridtype, const bool align_corners); + +#endif \ No newline at end of file diff --git a/internal/camera_utils.py b/internal/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a794a47d6cd427bd8c2cf988a2a840a15278bf43 --- /dev/null +++ b/internal/camera_utils.py @@ -0,0 +1,673 @@ +import enum +from internal import configs +from internal import stepfun +from internal import utils +import numpy as np +import scipy + + +def convert_to_ndc(origins, + directions, + pixtocam, + near: float = 1.): + """Converts a set of rays to normalized device coordinates (NDC). + + Args: + origins: ndarray(float32), [..., 3], world space ray origins. + directions: ndarray(float32), [..., 3], world space ray directions. + pixtocam: ndarray(float32), [3, 3], inverse intrinsic matrix. + near: float, near plane along the negative z axis. + + Returns: + origins_ndc: ndarray(float32), [..., 3]. + directions_ndc: ndarray(float32), [..., 3]. + + This function assumes input rays should be mapped into the NDC space for a + perspective projection pinhole camera, with identity extrinsic matrix (pose) + and intrinsic parameters defined by inputs focal, width, and height. + + The near value specifies the near plane of the frustum, and the far plane is + assumed to be infinity. + + The ray bundle for the identity pose camera will be remapped to parallel rays + within the (-1, -1, -1) to (1, 1, 1) cube. Any other ray in the original + world space can be remapped as long as it has dz < 0 (ray direction has a + negative z-coord); this allows us to share a common NDC space for "forward + facing" scenes. + + Note that + projection(origins + t * directions) + will NOT be equal to + origins_ndc + t * directions_ndc + and that the directions_ndc are not unit length. Rather, directions_ndc is + defined such that the valid near and far planes in NDC will be 0 and 1. + + See Appendix C in https://arxiv.org/abs/2003.08934 for additional details. + """ + + # Shift ray origins to near plane, such that oz = -near. + # This makes the new near bound equal to 0. + t = -(near + origins[..., 2]) / directions[..., 2] + origins = origins + t[..., None] * directions + + dx, dy, dz = np.moveaxis(directions, -1, 0) + ox, oy, oz = np.moveaxis(origins, -1, 0) + + xmult = 1. / pixtocam[0, 2] # Equal to -2. * focal / cx + ymult = 1. / pixtocam[1, 2] # Equal to -2. 
* focal / cy + + # Perspective projection into NDC for the t = 0 near points + # origins + 0 * directions + origins_ndc = np.stack([xmult * ox / oz, ymult * oy / oz, + -np.ones_like(oz)], axis=-1) + + # Perspective projection into NDC for the t = infinity far points + # origins + infinity * directions + infinity_ndc = np.stack([xmult * dx / dz, ymult * dy / dz, + np.ones_like(oz)], + axis=-1) + + # directions_ndc points from origins_ndc to infinity_ndc + directions_ndc = infinity_ndc - origins_ndc + + return origins_ndc, directions_ndc + + +def pad_poses(p): + """Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1].""" + bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape) + return np.concatenate([p[..., :3, :4], bottom], axis=-2) + + +def unpad_poses(p): + """Remove the homogeneous bottom row from [..., 4, 4] pose matrices.""" + return p[..., :3, :4] + + +def recenter_poses(poses): + """Recenter poses around the origin.""" + cam2world = average_pose(poses) + transform = np.linalg.inv(pad_poses(cam2world)) + poses = transform @ pad_poses(poses) + return unpad_poses(poses), transform + + +def average_pose(poses): + """New pose using average position, z-axis, and up vector of input poses.""" + position = poses[:, :3, 3].mean(0) + z_axis = poses[:, :3, 2].mean(0) + up = poses[:, :3, 1].mean(0) + cam2world = viewmatrix(z_axis, up, position) + return cam2world + + +def viewmatrix(lookdir, up, position): + """Construct lookat view matrix.""" + vec2 = normalize(lookdir) + vec0 = normalize(np.cross(up, vec2)) + vec1 = normalize(np.cross(vec2, vec0)) + m = np.stack([vec0, vec1, vec2, position], axis=1) + return m + + +def normalize(x): + """Normalization helper function.""" + return x / np.linalg.norm(x) + + +def focus_point_fn(poses): + """Calculate nearest point to all focal axes in poses.""" + directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4] + m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1]) + mt_m = np.transpose(m, [0, 2, 1]) @ m + focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0] + return focus_pt + + +# Constants for generate_spiral_path(): +NEAR_STRETCH = .9 # Push forward near bound for forward facing render path. +FAR_STRETCH = 5. # Push back far bound for forward facing render path. +FOCUS_DISTANCE = .75 # Relative weighting of near, far bounds for render path. + + +def generate_spiral_path(poses, bounds, n_frames=120, n_rots=2, zrate=.5): + """Calculates a forward facing spiral path for rendering.""" + # Find a reasonable 'focus depth' for this dataset as a weighted average + # of conservative near and far bounds in disparity space. + near_bound = bounds.min() * NEAR_STRETCH + far_bound = bounds.max() * FAR_STRETCH + # All cameras will point towards the world space point (0, 0, -focal). + focal = 1 / (((1 - FOCUS_DISTANCE) / near_bound + FOCUS_DISTANCE / far_bound)) + + # Get radii for spiral path using 90th percentile of camera positions. + positions = poses[:, :3, 3] + radii = np.percentile(np.abs(positions), 90, 0) + radii = np.concatenate([radii, [1.]]) + + # Generate poses for spiral path. + render_poses = [] + cam2world = average_pose(poses) + up = poses[:, :3, 1].mean(0) + for theta in np.linspace(0., 2. * np.pi * n_rots, n_frames, endpoint=False): + t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.] + position = cam2world @ t + lookat = cam2world @ [0, 0, -focal, 1.] 
+ z_axis = position - lookat + render_poses.append(viewmatrix(z_axis, up, position)) + render_poses = np.stack(render_poses, axis=0) + return render_poses + + +def transform_poses_pca(poses): + """Transforms poses so principal components lie on XYZ axes. + + Args: + poses: a (N, 3, 4) array containing the cameras' camera to world transforms. + + Returns: + A tuple (poses, transform), with the transformed poses and the applied + camera_to_world transforms. + """ + t = poses[:, :3, 3] + t_mean = t.mean(axis=0) + t = t - t_mean + + eigval, eigvec = np.linalg.eig(t.T @ t) + # Sort eigenvectors in order of largest to smallest eigenvalue. + inds = np.argsort(eigval)[::-1] + eigvec = eigvec[:, inds] + rot = eigvec.T + if np.linalg.det(rot) < 0: + rot = np.diag(np.array([1, 1, -1])) @ rot + + transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1) + poses_recentered = unpad_poses(transform @ pad_poses(poses)) + transform = np.concatenate([transform, np.eye(4)[3:]], axis=0) + + # Flip coordinate system if z component of y-axis is negative + if poses_recentered.mean(axis=0)[2, 1] < 0: + poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered + transform = np.diag(np.array([1, -1, -1, 1])) @ transform + + # Just make sure it's it in the [-1, 1]^3 cube + scale_factor = 1. / np.max(np.abs(poses_recentered[:, :3, 3])) + poses_recentered[:, :3, 3] *= scale_factor + transform = np.diag(np.array([scale_factor] * 3 + [1])) @ transform + + return poses_recentered, transform + + +def generate_ellipse_path(poses, n_frames=120, const_speed=True, z_variation=0., z_phase=0.): + """Generate an elliptical render path based on the given poses.""" + # Calculate the focal point for the path (cameras point toward this). + center = focus_point_fn(poses) + # Path height sits at z=0 (in middle of zero-mean capture pattern). + offset = np.array([center[0], center[1], 0]) + + # Calculate scaling for ellipse axes based on input camera positions. + sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0) + # Use ellipse that is symmetric about the focal point in xy. + low = -sc + offset + high = sc + offset + # Optional height variation need not be symmetric + z_low = np.percentile((poses[:, :3, 3]), 10, axis=0) + z_high = np.percentile((poses[:, :3, 3]), 90, axis=0) + + def get_positions(theta): + # Interpolate between bounds with trig functions to get ellipse in x-y. + # Optionally also interpolate in z to change camera height along path. + return np.stack([ + low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5), + low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5), + z_variation * (z_low[2] + (z_high - z_low)[2] * + (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)), + ], -1) + + theta = np.linspace(0, 2. * np.pi, n_frames + 1, endpoint=True) + positions = get_positions(theta) + + if const_speed: + # Resample theta angles so that the velocity is closer to constant. + lengths = np.linalg.norm(positions[1:] - positions[:-1], axis=-1) + theta = stepfun.sample_np(None, theta, np.log(lengths), n_frames + 1) + positions = get_positions(theta) + + # Throw away duplicated last position. + positions = positions[:-1] + + # Set path's up vector to axis closest to average of input pose up vectors. 
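+ # Worked example (illustrative): if the averaged up vector is roughly (0.05, 0.10, 0.99),
+ # ind_up = 2 and up becomes (0, 0, 1); if its z component were negative, up would flip to (0, 0, -1).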
+ avg_up = poses[:, :3, 1].mean(0) + avg_up = avg_up / np.linalg.norm(avg_up) + ind_up = np.argmax(np.abs(avg_up)) + up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) + + return np.stack([viewmatrix(p - center, up, p) for p in positions]) + + +def generate_interpolated_path(poses, n_interp, spline_degree=5, + smoothness=.03, rot_weight=.1): + """Creates a smooth spline path between input keyframe camera poses. + + Spline is calculated with poses in format (position, lookat-point, up-point). + + Args: + poses: (n, 3, 4) array of input pose keyframes. + n_interp: returned path will have n_interp * (n - 1) total poses. + spline_degree: polynomial degree of B-spline. + smoothness: parameter for spline smoothing, 0 forces exact interpolation. + rot_weight: relative weighting of rotation/translation in spline solve. + + Returns: + Array of new camera poses with shape (n_interp * (n - 1), 3, 4). + """ + + def poses_to_points(poses, dist): + """Converts from pose matrices to (position, lookat, up) format.""" + pos = poses[:, :3, -1] + lookat = poses[:, :3, -1] - dist * poses[:, :3, 2] + up = poses[:, :3, -1] + dist * poses[:, :3, 1] + return np.stack([pos, lookat, up], 1) + + def points_to_poses(points): + """Converts from (position, lookat, up) format to pose matrices.""" + return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points]) + + def interp(points, n, k, s): + """Runs multidimensional B-spline interpolation on the input points.""" + sh = points.shape + pts = np.reshape(points, (sh[0], -1)) + k = min(k, sh[0] - 1) + tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s) + u = np.linspace(0, 1, n, endpoint=False) + new_points = np.array(scipy.interpolate.splev(u, tck)) + new_points = np.reshape(new_points.T, (n, sh[1], sh[2])) + return new_points + + points = poses_to_points(poses, dist=rot_weight) + new_points = interp(points, + n_interp * (points.shape[0] - 1), + k=spline_degree, + s=smoothness) + return points_to_poses(new_points) + + +def interpolate_1d(x, n_interp, spline_degree, smoothness): + """Interpolate 1d signal x (by a factor of n_interp times).""" + t = np.linspace(0, 1, len(x), endpoint=True) + tck = scipy.interpolate.splrep(t, x, s=smoothness, k=spline_degree) + n = n_interp * (len(x) - 1) + u = np.linspace(0, 1, n, endpoint=False) + return scipy.interpolate.splev(u, tck) + + +def create_render_spline_path(config, image_names, poses, exposures): + """Creates spline interpolation render path from subset of dataset poses. + + Args: + config: configs.Config object. + image_names: either a directory of images or a text file of image names. + poses: [N, 3, 4] array of extrinsic camera pose matrices. + exposures: optional list of floating point exposure values. + + Returns: + spline_indices: list of indices used to select spline keyframe poses. + render_poses: array of interpolated extrinsic camera poses for the path. + render_exposures: optional list of interpolated exposures for the path. + """ + if utils.isdir(config.render_spline_keyframes): + # If directory, use image filenames. + keyframe_names = sorted(utils.listdir(config.render_spline_keyframes)) + else: + # If text file, treat each line as an image filename. + with utils.open_file(config.render_spline_keyframes, 'r') as fp: + # Decode bytes into string and split into lines. + keyframe_names = fp.read().decode('utf-8').splitlines() + # Grab poses corresponding to the image filenames. 
+ spline_indices = np.array( + [i for i, n in enumerate(image_names) if n in keyframe_names]) + keyframes = poses[spline_indices] + render_poses = generate_interpolated_path( + keyframes, + n_interp=config.render_spline_n_interp, + spline_degree=config.render_spline_degree, + smoothness=config.render_spline_smoothness, + rot_weight=.1) + if config.render_spline_interpolate_exposure: + if exposures is None: + raise ValueError('config.render_spline_interpolate_exposure is True but ' + 'create_render_spline_path() was passed exposures=None.') + # Interpolate per-frame exposure value. + log_exposure = np.log(exposures[spline_indices]) + # Use aggressive smoothing for exposure interpolation to avoid flickering. + log_exposure_interp = interpolate_1d( + log_exposure, + config.render_spline_n_interp, + spline_degree=5, + smoothness=20) + render_exposures = np.exp(log_exposure_interp) + else: + render_exposures = None + return spline_indices, render_poses, render_exposures + + +def intrinsic_matrix(fx, fy, cx, cy): + """Intrinsic matrix for a pinhole camera in OpenCV coordinate system.""" + return np.array([ + [fx, 0, cx], + [0, fy, cy], + [0, 0, 1.], + ]) + + +def get_pixtocam(focal, width, height): + """Inverse intrinsic matrix for a perfect pinhole camera.""" + camtopix = intrinsic_matrix(focal, focal, width * .5, height * .5) + return np.linalg.inv(camtopix) + + +def pixel_coordinates(width, height): + """Tuple of the x and y integer coordinates for a grid of pixels.""" + return np.meshgrid(np.arange(width), np.arange(height), indexing='xy') + + +def _compute_residual_and_jacobian(x, y, xd, yd, + k1=0.0, k2=0.0, k3=0.0, + k4=0.0, p1=0.0, p2=0.0, ): + """Auxiliary function of radial_and_tangential_undistort().""" + # Adapted from https://github.com/google/nerfies/blob/main/nerfies/camera.py + # let r(x, y) = x^2 + y^2; + # d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3 + + # k4 * r(x, y)^4; + r = x * x + y * y + d = 1.0 + r * (k1 + r * (k2 + r * (k3 + r * k4))) + + # The perfect projection is: + # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2); + # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2); + # + # Let's define + # + # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd; + # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd; + # + # We are looking for a solution that satisfies + # fx(x, y) = fy(x, y) = 0; + fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd + fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd + + # Compute derivative of d over [x, y] + d_r = (k1 + r * (2.0 * k2 + r * (3.0 * k3 + r * 4.0 * k4))) + d_x = 2.0 * x * d_r + d_y = 2.0 * y * d_r + + # Compute derivative of fx over x and y. + fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x + fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y + + # Compute derivative of fy over x and y. + fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x + fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y + + return fx, fy, fx_x, fx_y, fy_x, fy_y + + +def _radial_and_tangential_undistort(xd, yd, k1=0, k2=0, + k3=0, k4=0, p1=0, + p2=0, eps=1e-9, max_iterations=10): + """Computes undistorted (x, y) from (xd, yd).""" + # From https://github.com/google/nerfies/blob/main/nerfies/camera.py + # Initialize from the distorted point. 
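+ # The loop below is Newton's method on the 2x2 system fx(x, y) = fy(x, y) = 0: each step applies
+ # -J^{-1} [fx, fy] written out via Cramer's rule (note that `denominator` is the negated
+ # determinant of the Jacobian, which is why the numerators carry flipped signs).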
+ x = np.copy(xd) + y = np.copy(yd) + + for _ in range(max_iterations): + fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian( + x=x, y=y, xd=xd, yd=yd, k1=k1, k2=k2, k3=k3, k4=k4, p1=p1, p2=p2) + denominator = fy_x * fx_y - fx_x * fy_y + x_numerator = fx * fy_y - fy * fx_y + y_numerator = fy * fx_x - fx * fy_x + step_x = np.where( + np.abs(denominator) > eps, x_numerator / denominator, + np.zeros_like(denominator)) + step_y = np.where( + np.abs(denominator) > eps, y_numerator / denominator, + np.zeros_like(denominator)) + + x = x + step_x + y = y + step_y + + return x, y + + +class ProjectionType(enum.Enum): + """Camera projection type (standard perspective pinhole or fisheye model).""" + PERSPECTIVE = 'perspective' + FISHEYE = 'fisheye' + + +def pixels_to_rays(pix_x_int, pix_y_int, pixtocams, + camtoworlds, + distortion_params=None, + pixtocam_ndc=None, + camtype=ProjectionType.PERSPECTIVE): + """Calculates rays given pixel coordinates, intrinisics, and extrinsics. + + Given 2D pixel coordinates pix_x_int, pix_y_int for cameras with + inverse intrinsics pixtocams and extrinsics camtoworlds (and optional + distortion coefficients distortion_params and NDC space projection matrix + pixtocam_ndc), computes the corresponding 3D camera rays. + + Vectorized over the leading dimensions of the first four arguments. + + Args: + pix_x_int: int array, shape SH, x coordinates of image pixels. + pix_y_int: int array, shape SH, y coordinates of image pixels. + pixtocams: float array, broadcastable to SH + [3, 3], inverse intrinsics. + camtoworlds: float array, broadcastable to SH + [3, 4], camera extrinsics. + distortion_params: dict of floats, optional camera distortion parameters. + pixtocam_ndc: float array, [3, 3], optional inverse intrinsics for NDC. + camtype: camera_utils.ProjectionType, fisheye or perspective camera. + + Returns: + origins: float array, shape SH + [3], ray origin points. + directions: float array, shape SH + [3], ray direction vectors. + viewdirs: float array, shape SH + [3], normalized ray direction vectors. + radii: float array, shape SH + [1], ray differential radii. + imageplane: float array, shape SH + [2], xy coordinates on the image plane. + If the image plane is at world space distance 1 from the pinhole, then + imageplane will be the xy coordinates of a pixel in that space (so the + camera ray direction at the origin would be (x, y, -1) in OpenGL coords). + """ + + # Must add half pixel offset to shoot rays through pixel centers. + def pix_to_dir(x, y): + return np.stack([x + .5, y + .5, np.ones_like(x)], axis=-1) + + # We need the dx and dy rays to calculate ray radii for mip-NeRF cones. + pixel_dirs_stacked = np.stack([ + pix_to_dir(pix_x_int, pix_y_int), + pix_to_dir(pix_x_int + 1, pix_y_int), + pix_to_dir(pix_x_int, pix_y_int + 1) + ], axis=0) + + matmul = np.matmul + mat_vec_mul = lambda A, b: matmul(A, b[..., None])[..., 0] + + # Apply inverse intrinsic matrices. + camera_dirs_stacked = mat_vec_mul(pixtocams, pixel_dirs_stacked) + + if distortion_params is not None: + # Correct for distortion. 
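+ # distortion_params is expected to be a dict keyed like the keyword arguments of
+ # _radial_and_tangential_undistort, e.g. {'k1': ..., 'k2': ..., 'p1': ..., 'p2': ...}
+ # (illustrative shape only; the actual values are supplied by the caller).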
+ x, y = _radial_and_tangential_undistort( + camera_dirs_stacked[..., 0], + camera_dirs_stacked[..., 1], + **distortion_params) + camera_dirs_stacked = np.stack([x, y, np.ones_like(x)], -1) + + if camtype == ProjectionType.FISHEYE: + theta = np.sqrt(np.sum(np.square(camera_dirs_stacked[..., :2]), axis=-1)) + theta = np.minimum(np.pi, theta) + + sin_theta_over_theta = np.sin(theta) / theta + camera_dirs_stacked = np.stack([ + camera_dirs_stacked[..., 0] * sin_theta_over_theta, + camera_dirs_stacked[..., 1] * sin_theta_over_theta, + np.cos(theta), + ], axis=-1) + + # Flip from OpenCV to OpenGL coordinate system. + camera_dirs_stacked = matmul(camera_dirs_stacked, + np.diag(np.array([1., -1., -1.]))) + + # Extract 2D image plane (x, y) coordinates. + imageplane = camera_dirs_stacked[0, ..., :2] + + # Apply camera rotation matrices. + directions_stacked = mat_vec_mul(camtoworlds[..., :3, :3], + camera_dirs_stacked) + # Extract the offset rays. + directions, dx, dy = directions_stacked + + origins = np.broadcast_to(camtoworlds[..., :3, -1], directions.shape) + viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True) + + if pixtocam_ndc is None: + # Distance from each unit-norm direction vector to its neighbors. + dx_norm = np.linalg.norm(dx - directions, axis=-1) + dy_norm = np.linalg.norm(dy - directions, axis=-1) + + else: + # Convert ray origins and directions into projective NDC space. + origins_dx, _ = convert_to_ndc(origins, dx, pixtocam_ndc) + origins_dy, _ = convert_to_ndc(origins, dy, pixtocam_ndc) + origins, directions = convert_to_ndc(origins, directions, pixtocam_ndc) + + # In NDC space, we use the offset between origins instead of directions. + dx_norm = np.linalg.norm(origins_dx - origins, axis=-1) + dy_norm = np.linalg.norm(origins_dy - origins, axis=-1) + + # Cut the distance in half, multiply it to match the variance of a uniform + # distribution the size of a pixel (1/12, see the original mipnerf paper). + radii = (0.5 * (dx_norm + dy_norm))[..., None] * 2 / np.sqrt(12) + return origins, directions, viewdirs, radii, imageplane + + +def cast_ray_batch(cameras, pixels, camtype): + """Maps from input cameras and Pixel batch to output Ray batch. + + `cameras` is a Tuple of four sets of camera parameters. + pixtocams: 1 or N stacked [3, 3] inverse intrinsic matrices. + camtoworlds: 1 or N stacked [3, 4] extrinsic pose matrices. + distortion_params: optional, dict[str, float] containing pinhole model + distortion parameters. + pixtocam_ndc: optional, [3, 3] inverse intrinsic matrix for mapping to NDC. + + Args: + cameras: described above. + pixels: integer pixel coordinates and camera indices, plus ray metadata. + These fields can be an arbitrary batch shape. + camtype: camera_utils.ProjectionType, fisheye or perspective camera. + + Returns: + rays: Rays dataclass with computed 3D world space ray data. + """ + pixtocams, camtoworlds, distortion_params, pixtocam_ndc = cameras + + # pixels.cam_idx has shape [..., 1], remove this hanging dimension. + cam_idx = pixels['cam_idx'][..., 0] + batch_index = lambda arr: arr if arr.ndim == 2 else arr[cam_idx] + + # Compute rays from pixel coordinates. + origins, directions, viewdirs, radii, imageplane = pixels_to_rays( + pixels['pix_x_int'], + pixels['pix_y_int'], + batch_index(pixtocams), + batch_index(camtoworlds), + distortion_params=distortion_params, + pixtocam_ndc=pixtocam_ndc, + camtype=camtype) + + # Create Rays data structure. 
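+    # Optional per-ray metadata (lossmult, near, far, exposure info) is passed
+    # through from the `pixels` dict; pixels.get() yields None for missing keys.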
+ return dict( + origins=origins, + directions=directions, + viewdirs=viewdirs, + radii=radii, + imageplane=imageplane, + lossmult=pixels.get('lossmult'), + near=pixels.get('near'), + far=pixels.get('far'), + cam_idx=pixels.get('cam_idx'), + exposure_idx=pixels.get('exposure_idx'), + exposure_values=pixels.get('exposure_values'), + ) + + +def cast_pinhole_rays(camtoworld, height, width, focal, near, far): + """Wrapper for generating a pinhole camera ray batch (w/o distortion).""" + + pix_x_int, pix_y_int = pixel_coordinates(width, height) + pixtocam = get_pixtocam(focal, width, height) + + origins, directions, viewdirs, radii, imageplane = pixels_to_rays(pix_x_int, pix_y_int, pixtocam, camtoworld) + + broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None] + ray_kwargs = { + 'lossmult': broadcast_scalar(1.), + 'near': broadcast_scalar(near), + 'far': broadcast_scalar(far), + 'cam_idx': broadcast_scalar(0), + } + + return dict(origins=origins, + directions=directions, + viewdirs=viewdirs, + radii=radii, + imageplane=imageplane, + **ray_kwargs) + + +def cast_spherical_rays(camtoworld, height, width, near, far): + """Generates a spherical camera ray batch.""" + + theta_vals = np.linspace(0, 2 * np.pi, width + 1) + phi_vals = np.linspace(0, np.pi, height + 1) + theta, phi = np.meshgrid(theta_vals, phi_vals, indexing='xy') + + # Spherical coordinates in camera reference frame (y is up). + directions = np.stack([ + -np.sin(phi) * np.sin(theta), + np.cos(phi), + np.sin(phi) * np.cos(theta), + ], axis=-1) + + matmul = np.matmul + directions = matmul(camtoworld[:3, :3], directions[..., None])[..., 0] + + dy = np.diff(directions[:, :-1], axis=0) + dx = np.diff(directions[:-1, :], axis=1) + directions = directions[:-1, :-1] + viewdirs = directions + + origins = np.broadcast_to(camtoworld[:3, -1], directions.shape) + + dx_norm = np.linalg.norm(dx, axis=-1) + dy_norm = np.linalg.norm(dy, axis=-1) + radii = (0.5 * (dx_norm + dy_norm))[..., None] * 2 / np.sqrt(12) + + imageplane = np.zeros_like(directions[..., :2]) + + broadcast_scalar = lambda x: np.broadcast_to(x, radii.shape[:-1])[..., None] + ray_kwargs = { + 'lossmult': broadcast_scalar(1.), + 'near': broadcast_scalar(near), + 'far': broadcast_scalar(far), + 'cam_idx': broadcast_scalar(0), + } + + return dict(origins=origins, + directions=directions, + viewdirs=viewdirs, + radii=radii, + imageplane=imageplane, + **ray_kwargs) diff --git a/internal/checkpoints.py b/internal/checkpoints.py new file mode 100644 index 0000000000000000000000000000000000000000..aa8a174a3ff02b5a9b6b6bba7e514bac0c1bf5c2 --- /dev/null +++ b/internal/checkpoints.py @@ -0,0 +1,38 @@ +import os +import shutil + +import accelerate +import torch +import glob + + +def restore_checkpoint( + checkpoint_dir, + accelerator: accelerate.Accelerator, + logger=None +): + dirs = glob.glob(os.path.join(checkpoint_dir, "*")) + dirs.sort() + path = dirs[-1] if len(dirs) > 0 else None + if path is None: + if logger is not None: + logger.info("Checkpoint does not exist. 
Starting a new training run.") + init_step = 0 + else: + if logger is not None: + logger.info(f"Resuming from checkpoint {path}") + accelerator.load_state(path) + init_step = int(os.path.basename(path)) + return init_step + + +def save_checkpoint(save_dir, + accelerator: accelerate.Accelerator, + step=0, + total_limit=3): + if total_limit > 0: + folders = glob.glob(os.path.join(save_dir, "*")) + folders.sort() + for folder in folders[: len(folders) + 1 - total_limit]: + shutil.rmtree(folder) + accelerator.save_state(os.path.join(save_dir, f"{step:06d}")) diff --git a/internal/configs.py b/internal/configs.py new file mode 100644 index 0000000000000000000000000000000000000000..4197bd982993c9667fbf655fe2600998815e5a63 --- /dev/null +++ b/internal/configs.py @@ -0,0 +1,177 @@ +import dataclasses +import os +from typing import Any, Callable, Optional, Tuple, List +import numpy as np +import torch +import torch.nn.functional as F +from absl import flags +import gin +from internal import utils + +gin.add_config_file_search_path('configs/') + +configurables = { + 'torch': [torch.reciprocal, torch.log, torch.log1p, torch.exp, torch.sqrt, torch.square], +} + +for module, configurables in configurables.items(): + for configurable in configurables: + gin.config.external_configurable(configurable, module=module) + + +@gin.configurable() +@dataclasses.dataclass +class Config: + """Configuration flags for everything.""" + seed = 0 + dataset_loader: str = 'llff' # The type of dataset loader to use. + batching: str = 'all_images' # Batch composition, [single_image, all_images]. + batch_size: int = 2 ** 16 # The number of rays/pixels in each batch. + patch_size: int = 1 # Resolution of patches sampled for training batches. + factor: int = 4 # The downsample factor of images, 0 for no downsampling. + multiscale: bool = False # use multiscale data for training. + multiscale_levels: int = 4 # number of multiscale levels. + # ordering (affects heldout test set). + forward_facing: bool = False # Set to True for forward-facing LLFF captures. + render_path: bool = False # If True, render a path. Used only by LLFF. + llffhold: int = 8 # Use every Nth image for the test set. Used only by LLFF. + # If true, use all input images for training. + llff_use_all_images_for_training: bool = False + llff_use_all_images_for_testing: bool = False + use_tiffs: bool = False # If True, use 32-bit TIFFs. Used only by Blender. + compute_disp_metrics: bool = False # If True, load and compute disparity MSE. + compute_normal_metrics: bool = False # If True, load and compute normal MAE. + disable_multiscale_loss: bool = False # If True, disable multiscale loss. + randomized: bool = True # Use randomized stratified sampling. + near: float = 2. # Near plane distance. + far: float = 6. # Far plane distance. + exp_name: str = "test" # experiment name + data_dir: Optional[str] = "/SSD_DISK/datasets/360_v2/bicycle" # Input data directory. + vocab_tree_path: Optional[str] = None # Path to vocab tree for COLMAP. + render_chunk_size: int = 65536 # Chunk size for whole-image renderings. + num_showcase_images: int = 5 # The number of test-set images to showcase. + deterministic_showcase: bool = True # If True, showcase the same images. + vis_num_rays: int = 16 # The number of rays to visualize. + # Decimate images for tensorboard (ie, x[::d, ::d]) to conserve memory usage. + vis_decimate: int = 0 + + # Only used by train.py: + max_steps: int = 25000 # The number of optimization steps. 
+    early_exit_steps: Optional[int] = None  # Early stopping, for debugging.
+    checkpoint_every: int = 5000  # The number of steps to save a checkpoint.
+    resume_from_checkpoint: bool = True  # Whether to resume from the latest checkpoint.
+    checkpoints_total_limit: int = 1  # Maximum number of checkpoints to keep.
+    gradient_scaling: bool = False  # If True, scale gradients as in https://gradient-scaling.github.io/.
+    print_every: int = 100  # The number of steps between reports to tensorboard.
+    train_render_every: int = 500  # Steps between test set renders when training.
+    data_loss_type: str = 'charb'  # What kind of loss to use ('mse' or 'charb').
+    charb_padding: float = 0.001  # The padding used for Charbonnier loss.
+    data_loss_mult: float = 1.0  # Mult for the finest data term in the loss.
+    data_coarse_loss_mult: float = 0.  # Multiplier for the coarser data terms.
+    interlevel_loss_mult: float = 0.0  # Mult. for the loss on the proposal MLP.
+    anti_interlevel_loss_mult: float = 0.01  # Mult. for the anti-aliased interlevel loss on the proposal MLP.
+    pulse_width = [0.03, 0.003]  # Pulse widths used by the anti-aliased interlevel loss.
+    orientation_loss_mult: float = 0.0  # Multiplier on the orientation loss.
+    orientation_coarse_loss_mult: float = 0.0  # Coarser orientation loss weights.
+    # What that loss is imposed on, options are 'normals' or 'normals_pred'.
+    orientation_loss_target: str = 'normals_pred'
+    predicted_normal_loss_mult: float = 0.0  # Mult. on the predicted normal loss.
+    # Mult. on the coarser predicted normal loss.
+    predicted_normal_coarse_loss_mult: float = 0.0
+    hash_decay_mults: float = 0.1  # Mult. for the hash feature decay loss.
+
+    lr_init: float = 0.01  # The initial learning rate.
+    lr_final: float = 0.001  # The final learning rate.
+    lr_delay_steps: int = 5000  # The number of "warmup" learning steps.
+    lr_delay_mult: float = 1e-8  # How severe the "warmup" should be.
+    adam_beta1: float = 0.9  # Adam's beta1 hyperparameter.
+    adam_beta2: float = 0.99  # Adam's beta2 hyperparameter.
+    adam_eps: float = 1e-15  # Adam's epsilon hyperparameter.
+    grad_max_norm: float = 0.  # Gradient clipping magnitude, disabled if == 0.
+    grad_max_val: float = 0.  # Gradient clipping value, disabled if == 0.
+    distortion_loss_mult: float = 0.005  # Multiplier on the distortion loss.
+    opacity_loss_mult: float = 0.  # Multiplier on the opacity loss.
+
+    # Only used by eval.py:
+    eval_only_once: bool = True  # If True evaluate the model only once, otherwise loop.
+    eval_save_output: bool = True  # If True save predicted images to disk.
+    eval_save_ray_data: bool = False  # If True save individual ray traces.
+    eval_render_interval: int = 1  # The interval between images saved to disk.
+    eval_dataset_limit: int = np.iinfo(np.int32).max  # Num test images to eval.
+    eval_quantize_metrics: bool = True  # If True, run metrics on 8-bit images.
+    eval_crop_borders: int = 0  # Ignore c border pixels in eval (x[c:-c, c:-c]).
+
+    # Only used by render.py:
+    render_video_fps: int = 60  # Framerate in frames-per-second.
+    render_video_crf: int = 18  # Constant rate factor for ffmpeg video quality.
+    render_path_frames: int = 120  # Number of frames in render path.
+    z_variation: float = 0.  # How much height variation in render path.
+    z_phase: float = 0.  # Phase offset for height variation in render path.
+    render_dist_percentile: float = 0.5  # How much to trim from near/far planes.
+    render_dist_curve_fn: Callable[..., Any] = np.log  # How depth is curved.
+    render_path_file: Optional[str] = None  # Numpy render pose file to load.
+    render_resolution: Optional[Tuple[int, int]] = None  # Render resolution, as
+    # (width, height).
+ render_focal: Optional[float] = None # Render focal length. + render_camtype: Optional[str] = None # 'perspective', 'fisheye', or 'pano'. + render_spherical: bool = False # Render spherical 360 panoramas. + render_save_async: bool = True # Save to CNS using a separate thread. + + render_spline_keyframes: Optional[str] = None # Text file containing names of + # images to be used as spline + # keyframes, OR directory + # containing those images. + render_spline_n_interp: int = 30 # Num. frames to interpolate per keyframe. + render_spline_degree: int = 5 # Polynomial degree of B-spline interpolation. + render_spline_smoothness: float = .03 # B-spline smoothing factor, 0 for + # exact interpolation of keyframes. + # Interpolate per-frame exposure value from spline keyframes. + render_spline_interpolate_exposure: bool = False + + # Flags for raw datasets. + rawnerf_mode: bool = False # Load raw images and train in raw color space. + exposure_percentile: float = 97. # Image percentile to expose as white. + num_border_pixels_to_mask: int = 0 # During training, discard N-pixel border + # around each input image. + apply_bayer_mask: bool = False # During training, apply Bayer mosaic mask. + autoexpose_renders: bool = False # During rendering, autoexpose each image. + # For raw test scenes, use affine raw-space color correction. + eval_raw_affine_cc: bool = False + + zero_glo: bool = False + + # marching cubes + valid_weight_thresh: float = 0.05 + isosurface_threshold: float = 20 + mesh_voxels: int = 512 ** 3 + visibility_resolution: int = 512 + mesh_radius: float = 1.0 # mesh radius * 2 = in contract space + mesh_max_radius: float = 10.0 # in world space + std_value: float = 0.0 # std of the sampled points + compute_visibility: bool = False + extract_visibility: bool = True + decimate_target: int = -1 + vertex_color: bool = True + vertex_projection: bool = True + + # tsdf + tsdf_radius: float = 2.0 + tsdf_resolution: int = 512 + truncation_margin: float = 5.0 + tsdf_max_radius: float = 10.0 # in world space + + +def define_common_flags(): + # Define the flags used by both train.py and eval.py + flags.DEFINE_string('mode', None, 'Required by GINXM, not used.') + flags.DEFINE_string('base_folder', None, 'Required by GINXM, not used.') + flags.DEFINE_multi_string('gin_bindings', None, 'Gin parameter bindings.') + flags.DEFINE_multi_string('gin_configs', None, 'Gin config files.') + + +def load_config(): + """Load the config, and optionally checkpoint it.""" + gin.parse_config_files_and_bindings( + flags.FLAGS.gin_configs, flags.FLAGS.gin_bindings, skip_unknown=True) + config = Config() + return config diff --git a/internal/coord.py b/internal/coord.py new file mode 100644 index 0000000000000000000000000000000000000000..0990f408bd6fc5996d73f2fb3021efb2d9b72ada --- /dev/null +++ b/internal/coord.py @@ -0,0 +1,225 @@ +from internal import math +from internal import utils +import numpy as np +import torch +# from torch.func import vmap, jacrev + + +def contract(x): + """Contracts points towards the origin (Eq 10 of arxiv.org/abs/2111.12077).""" + eps = torch.finfo(x.dtype).eps + # eps = 1e-3 + # Clamping to eps prevents non-finite gradients when x == 0. + x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps) + z = torch.where(x_mag_sq <= 1, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x) + return z + + +def inv_contract(z): + """The inverse of contract().""" + eps = torch.finfo(z.dtype).eps + + # Clamping to eps prevents non-finite gradients when z == 0. 
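+    # For ||z|| > 1 the forward contraction maps ||x|| to 2 - 1 / ||x||, so the
+    # inverse is x = z / (2 * ||z|| - ||z||**2), computed in the second branch below.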
+ z_mag_sq = torch.sum(z ** 2, dim=-1, keepdim=True).clamp_min(eps) + x = torch.where(z_mag_sq <= 1, z, z / (2 * torch.sqrt(z_mag_sq) - z_mag_sq).clamp_min(eps)) + return x + + +def inv_contract_np(z): + """The inverse of contract().""" + eps = np.finfo(z.dtype).eps + + # Clamping to eps prevents non-finite gradients when z == 0. + z_mag_sq = np.maximum(np.sum(z ** 2, axis=-1, keepdims=True), eps) + x = np.where(z_mag_sq <= 1, z, z / np.maximum(2 * np.sqrt(z_mag_sq) - z_mag_sq, eps)) + return x + + +def contract_tuple(x): + res = contract(x) + return res, res + + +def contract_mean_jacobi(x): + eps = torch.finfo(x.dtype).eps + # eps = 1e-3 + + # Clamping to eps prevents non-finite gradients when x == 0. + x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps) + x_mag_sqrt = torch.sqrt(x_mag_sq) + x_xT = math.matmul(x[..., None], x[..., None, :]) + mask = x_mag_sq <= 1 + z = torch.where(x_mag_sq <= 1, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x) + + eye = torch.broadcast_to(torch.eye(3, device=x.device), z.shape[:-1] + z.shape[-1:] * 2) + jacobi = (2 * x_xT * (1 - x_mag_sqrt[..., None]) + (2 * x_mag_sqrt[..., None] ** 3 - x_mag_sqrt[..., None] ** 2) * eye) / x_mag_sqrt[..., None] ** 4 + jacobi = torch.where(mask[..., None], eye, jacobi) + return z, jacobi + + +def contract_mean_std(x, std): + eps = torch.finfo(x.dtype).eps + # eps = 1e-3 + # Clamping to eps prevents non-finite gradients when x == 0. + x_mag_sq = torch.sum(x ** 2, dim=-1, keepdim=True).clamp_min(eps) + x_mag_sqrt = torch.sqrt(x_mag_sq) + mask = x_mag_sq <= 1 + z = torch.where(mask, x, ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq) * x) + # det_13 = ((1 / x_mag_sq) * ((2 / x_mag_sqrt - 1 / x_mag_sq) ** 2)) ** (1 / 3) + det_13 = (torch.pow(2 * x_mag_sqrt - 1, 1/3) / x_mag_sqrt) ** 2 + + std = torch.where(mask[..., 0], std, det_13[..., 0] * std) + return z, std + + +@torch.no_grad() +def track_linearize(fn, mean, std): + """Apply function `fn` to a set of means and covariances, ala a Kalman filter. + + We can analytically transform a Gaussian parameterized by `mean` and `cov` + with a function `fn` by linearizing `fn` around `mean`, and taking advantage + of the fact that Covar[Ax + y] = A(Covar[x])A^T (see + https://cs.nyu.edu/~roweis/notes/gaussid.pdf for details). + + Args: + fn: the function applied to the Gaussians parameterized by (mean, cov). + mean: a tensor of means, where the last axis is the dimension. + std: a tensor of covariances, where the last two axes are the dimensions. + + Returns: + fn_mean: the transformed means. + fn_cov: the transformed covariances. 
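+
+    Note: only fn == 'contract' is supported here. Rather than linearizing with
+    autograd, contract_mean_std() rescales the isotropic std by det(J)^(1/3) of
+    the contraction Jacobian, which has a closed form.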
+ """ + if fn == 'contract': + fn = contract_mean_jacobi + else: + raise NotImplementedError + + pre_shape = mean.shape[:-1] + mean = mean.reshape(-1, 3) + std = std.reshape(-1) + + # jvp_1, mean_1 = vmap(jacrev(contract_tuple, has_aux=True))(mean) + # std_1 = std * torch.linalg.det(jvp_1) ** (1 / mean.shape[-1]) + # + # mean_2, jvp_2 = fn(mean) + # std_2 = std * torch.linalg.det(jvp_2) ** (1 / mean.shape[-1]) + # + # mean_3, std_3 = contract_mean_std(mean, std) # calculate det explicitly by using eigenvalues + # torch.allclose(std_1, std_3, atol=1e-7) # True + # torch.allclose(mean_1, mean_3) # True + # import ipdb; ipdb.set_trace() + mean, std = contract_mean_std(mean, std) # calculate det explicitly by using eigenvalues + + mean = mean.reshape(*pre_shape, 3) + std = std.reshape(*pre_shape) + return mean, std + + +def power_transformation(x, lam): + """ + power transformation for Eq(4) in zip-nerf + """ + lam_1 = np.abs(lam - 1) + return lam_1 / lam * ((x / lam_1 + 1) ** lam - 1) + + +def inv_power_transformation(x, lam): + """ + inverse power transformation + """ + lam_1 = np.abs(lam - 1) + eps = torch.finfo(x.dtype).eps # may cause inf + # eps = 1e-3 + return ((x * lam / lam_1 + 1 + eps) ** (1 / lam) - 1) * lam_1 + + +def construct_ray_warps(fn, t_near, t_far, lam=None): + """Construct a bijection between metric distances and normalized distances. + + See the text around Equation 11 in https://arxiv.org/abs/2111.12077 for a + detailed explanation. + + Args: + fn: the function to ray distances. + t_near: a tensor of near-plane distances. + t_far: a tensor of far-plane distances. + lam: for lam in Eq(4) in zip-nerf + + Returns: + t_to_s: a function that maps distances to normalized distances in [0, 1]. + s_to_t: the inverse of t_to_s. + """ + if fn is None: + fn_fwd = lambda x: x + fn_inv = lambda x: x + elif fn == 'piecewise': + # Piecewise spacing combining identity and 1/x functions to allow t_near=0. + fn_fwd = lambda x: torch.where(x < 1, .5 * x, 1 - .5 / x) + fn_inv = lambda x: torch.where(x < .5, 2 * x, .5 / (1 - x)) + elif fn == 'power_transformation': + fn_fwd = lambda x: power_transformation(x * 2, lam=lam) + fn_inv = lambda y: inv_power_transformation(y, lam=lam) / 2 + else: + inv_mapping = { + 'reciprocal': torch.reciprocal, + 'log': torch.exp, + 'exp': torch.log, + 'sqrt': torch.square, + 'square': torch.sqrt, + } + fn_fwd = fn + fn_inv = inv_mapping[fn.__name__] + + s_near, s_far = [fn_fwd(x) for x in (t_near, t_far)] + t_to_s = lambda t: (fn_fwd(t) - s_near) / (s_far - s_near) + s_to_t = lambda s: fn_inv(s * s_far + (1 - s) * s_near) + return t_to_s, s_to_t + + +def expected_sin(mean, var): + """Compute the mean of sin(x), x ~ N(mean, var).""" + return torch.exp(-0.5 * var) * math.safe_sin(mean) # large var -> small value. + + +def integrated_pos_enc(mean, var, min_deg, max_deg): + """Encode `x` with sinusoids scaled by 2^[min_deg, max_deg). + + Args: + mean: tensor, the mean coordinates to be encoded + var: tensor, the variance of the coordinates to be encoded. + min_deg: int, the min degree of the encoding. + max_deg: int, the max degree of the encoding. + + Returns: + encoded: tensor, encoded variables. 
+ """ + scales = 2 ** torch.arange(min_deg, max_deg, device=mean.device) + shape = mean.shape[:-1] + (-1,) + scaled_mean = (mean[..., None, :] * scales[:, None]).reshape(*shape) + scaled_var = (var[..., None, :] * scales[:, None] ** 2).reshape(*shape) + + return expected_sin( + torch.cat([scaled_mean, scaled_mean + 0.5 * torch.pi], dim=-1), + torch.cat([scaled_var] * 2, dim=-1)) + + +def lift_and_diagonalize(mean, cov, basis): + """Project `mean` and `cov` onto basis and diagonalize the projected cov.""" + fn_mean = math.matmul(mean, basis) + fn_cov_diag = torch.sum(basis * math.matmul(cov, basis), dim=-2) + return fn_mean, fn_cov_diag + + +def pos_enc(x, min_deg, max_deg, append_identity=True): + """The positional encoding used by the original NeRF paper.""" + scales = 2 ** torch.arange(min_deg, max_deg, device=x.device) + shape = x.shape[:-1] + (-1,) + scaled_x = (x[..., None, :] * scales[:, None]).reshape(*shape) + # Note that we're not using safe_sin, unlike IPE. + four_feat = torch.sin( + torch.cat([scaled_x, scaled_x + 0.5 * torch.pi], dim=-1)) + if append_identity: + return torch.cat([x] + [four_feat], dim=-1) + else: + return four_feat diff --git a/internal/datasets.py b/internal/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..c1082b85fada0b8bfd7b19f44a8568999b904421 --- /dev/null +++ b/internal/datasets.py @@ -0,0 +1,1016 @@ +import abc +import copy +import json +import os +import cv2 +from internal import camera_utils +from internal import configs +from internal import image as lib_image +from internal import raw_utils +from internal import utils +from collections import defaultdict +import numpy as np +import cv2 +from PIL import Image +import torch +from tqdm import tqdm +# This is ugly, but it works. +import sys + +sys.path.insert(0, 'internal/pycolmap') +sys.path.insert(0, 'internal/pycolmap/pycolmap') +import pycolmap + + +def load_dataset(split, train_dir, config: configs.Config): + """Loads a split of a dataset using the data_loader specified by `config`.""" + if config.multiscale: + dataset_dict = { + 'llff': MultiLLFF, + } + else: + dataset_dict = { + 'blender': Blender, + 'llff': LLFF, + 'tat_nerfpp': TanksAndTemplesNerfPP, + 'tat_fvs': TanksAndTemplesFVS, + 'dtu': DTU, + } + return dataset_dict[config.dataset_loader](split, train_dir, config) + + +class NeRFSceneManager(pycolmap.SceneManager): + """COLMAP pose loader. + + Minor NeRF-specific extension to the third_party Python COLMAP loader: + google3/third_party/py/pycolmap/scene_manager.py + """ + + def process(self): + """Applies NeRF-specific postprocessing to the loaded pose data. + + Returns: + a tuple [image_names, poses, pixtocam, distortion_params]. + image_names: contains the only the basename of the images. + poses: [N, 4, 4] array containing the camera to world matrices. + pixtocam: [N, 3, 3] array containing the camera to pixel space matrices. + distortion_params: mapping of distortion param name to distortion + parameters. Cameras share intrinsics. Valid keys are k1, k2, p1 and p2. + """ + + self.load_cameras() + self.load_images() + # self.load_points3D() # For now, we do not need the point cloud data. + + # Assume shared intrinsics between all cameras. + cam = self.cameras[1] + + # Extract focal lengths and principal point parameters. + fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy + pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy)) + + # Extract extrinsic matrices in world-to-camera format. 
+ imdata = self.images + w2c_mats = [] + bottom = np.array([0, 0, 0, 1]).reshape(1, 4) + for k in imdata: + im = imdata[k] + rot = im.R() + trans = im.tvec.reshape(3, 1) + w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) + w2c_mats.append(w2c) + w2c_mats = np.stack(w2c_mats, axis=0) + + # Convert extrinsics to camera-to-world. + c2w_mats = np.linalg.inv(w2c_mats) + poses = c2w_mats[:, :3, :4] + + # Image names from COLMAP. No need for permuting the poses according to + # image names anymore. + names = [imdata[k].name for k in imdata] + + # Switch from COLMAP (right, down, fwd) to NeRF (right, up, back) frame. + poses = poses @ np.diag([1, -1, -1, 1]) + + # Get distortion parameters. + type_ = cam.camera_type + + if type_ == 0 or type_ == 'SIMPLE_PINHOLE': + params = None + camtype = camera_utils.ProjectionType.PERSPECTIVE + + elif type_ == 1 or type_ == 'PINHOLE': + params = None + camtype = camera_utils.ProjectionType.PERSPECTIVE + + if type_ == 2 or type_ == 'SIMPLE_RADIAL': + params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']} + params['k1'] = cam.k1 + camtype = camera_utils.ProjectionType.PERSPECTIVE + + elif type_ == 3 or type_ == 'RADIAL': + params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']} + params['k1'] = cam.k1 + params['k2'] = cam.k2 + camtype = camera_utils.ProjectionType.PERSPECTIVE + + elif type_ == 4 or type_ == 'OPENCV': + params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']} + params['k1'] = cam.k1 + params['k2'] = cam.k2 + params['p1'] = cam.p1 + params['p2'] = cam.p2 + camtype = camera_utils.ProjectionType.PERSPECTIVE + + elif type_ == 5 or type_ == 'OPENCV_FISHEYE': + params = {k: 0. for k in ['k1', 'k2', 'k3', 'k4']} + params['k1'] = cam.k1 + params['k2'] = cam.k2 + params['k3'] = cam.k3 + params['k4'] = cam.k4 + camtype = camera_utils.ProjectionType.FISHEYE + + return names, poses, pixtocam, params, camtype + + +def load_blender_posedata(data_dir, split=None): + """Load poses from `transforms.json` file, as used in Blender/NGP datasets.""" + suffix = '' if split is None else f'_{split}' + pose_file = os.path.join(data_dir, f'transforms{suffix}.json') + with utils.open_file(pose_file, 'r') as fp: + meta = json.load(fp) + names = [] + poses = [] + for _, frame in enumerate(meta['frames']): + filepath = os.path.join(data_dir, frame['file_path']) + if utils.file_exists(filepath): + names.append(frame['file_path'].split('/')[-1]) + poses.append(np.array(frame['transform_matrix'], dtype=np.float32)) + poses = np.stack(poses, axis=0) + + w = meta['w'] + h = meta['h'] + cx = meta['cx'] if 'cx' in meta else w / 2. + cy = meta['cy'] if 'cy' in meta else h / 2. + if 'fl_x' in meta: + fx = meta['fl_x'] + else: + fx = 0.5 * w / np.tan(0.5 * float(meta['camera_angle_x'])) + if 'fl_y' in meta: + fy = meta['fl_y'] + else: + fy = 0.5 * h / np.tan(0.5 * float(meta['camera_angle_y'])) + pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy)) + coeffs = ['k1', 'k2', 'p1', 'p2'] + if not any([c in meta for c in coeffs]): + params = None + else: + params = {c: (meta[c] if c in meta else 0.) for c in coeffs} + camtype = camera_utils.ProjectionType.PERSPECTIVE + return names, poses, pixtocam, params, camtype + + +class Dataset(torch.utils.data.Dataset): + """Dataset Base Class. + + Base class for a NeRF dataset. Creates batches of ray and color data used for + training or rendering a NeRF model. + + Each subclass is responsible for loading images and camera poses from disk by + implementing the _load_renderings() method. 
This data is used to generate
+    train and test batches of ray + color data for feeding through the NeRF model.
+    The ray parameters are computed in _make_ray_batch() (and in _generate_rays()
+    for the multiscale variant).
+
+    The public interface mimics the behavior of a standard machine learning
+    pipeline dataset provider that can provide infinite batches of data to the
+    training/testing pipelines without exposing any details of how the batches
+    are loaded/created or how this is parallelized. The initializer runs all
+    setup, including data loading from disk using _load_renderings(), and after
+    it returns the caller can request batches of data straight away.
+
+    This class is a torch.utils.data.Dataset: __getitem__() dispatches to
+    _next_train(), which samples a batch of random rays (or patches of rays),
+    for the training split (unless config.compute_visibility is set), and to
+    _next_test(), which returns all rays for one full image, otherwise. A
+    collate_fn() is also provided so the dataset can be wrapped in a torch
+    DataLoader.
+
+    Attributes:
+      alphas: np.ndarray, optional array of alpha channel data.
+      cameras: tuple summarizing all camera extrinsic/intrinsic/distortion params.
+      camtoworlds: np.ndarray, a list of extrinsic camera pose matrices.
+      camtype: camera_utils.ProjectionType, fisheye or perspective camera.
+      data_dir: str, location of the dataset on disk.
+      disp_images: np.ndarray, optional array of disparity (inverse depth) data.
+      distortion_params: dict, the camera distortion model parameters.
+      exposures: optional per-image exposure value (shutter * ISO / 1000).
+      far: float, far plane value for rays.
+      focal: float, focal length from camera intrinsics.
+      height: int, height of images.
+      images: np.ndarray, array of RGB image data.
+      metadata: dict, optional metadata for raw datasets.
+      near: float, near plane value for rays.
+      normal_images: np.ndarray, optional array of surface normal vector data.
+      pixtocams: np.ndarray, one or a list of inverse intrinsic camera matrices.
+      pixtocam_ndc: np.ndarray, the inverse intrinsic matrix used for NDC space.
+      poses: np.ndarray, optional array of auxiliary camera pose data.
+      rays: utils.Rays, ray data for every pixel in the dataset.
+      render_exposures: optional list of exposure values for the render path.
+      render_path: bool, indicates if a smooth camera path should be generated.
+      size: int, number of images in the dataset.
+      split: str, indicates if this is a "train" or "test" dataset.
+      width: int, width of images.
+ """ + + def __init__(self, + split: str, + data_dir: str, + config: configs.Config): + super().__init__() + + # Initialize attributes + self._patch_size = max(config.patch_size, 1) + self._batch_size = config.batch_size // config.world_size + if self._patch_size ** 2 > self._batch_size: + raise ValueError(f'Patch size {self._patch_size}^2 too large for ' + + f'per-process batch size {self._batch_size}') + self._batching = utils.BatchingMethod(config.batching) + self._use_tiffs = config.use_tiffs + self._load_disps = config.compute_disp_metrics + self._load_normals = config.compute_normal_metrics + self._num_border_pixels_to_mask = config.num_border_pixels_to_mask + self._apply_bayer_mask = config.apply_bayer_mask + self._render_spherical = False + + self.config = config + self.global_rank = config.global_rank + self.world_size = config.world_size + self.split = utils.DataSplit(split) + self.data_dir = data_dir + self.near = config.near + self.far = config.far + self.render_path = config.render_path + self.distortion_params = None + self.disp_images = None + self.normal_images = None + self.alphas = None + self.poses = None + self.pixtocam_ndc = None + self.metadata = None + self.camtype = camera_utils.ProjectionType.PERSPECTIVE + self.exposures = None + self.render_exposures = None + + # Providing type comments for these attributes, they must be correctly + # initialized by _load_renderings() (see docstring) in any subclass. + self.images: np.ndarray = None + self.camtoworlds: np.ndarray = None + self.pixtocams: np.ndarray = None + self.height: int = None + self.width: int = None + + # Load data from disk using provided config parameters. + self._load_renderings(config) + + if self.render_path: + if config.render_path_file is not None: + with utils.open_file(config.render_path_file, 'rb') as fp: + render_poses = np.load(fp) + self.camtoworlds = render_poses + if config.render_resolution is not None: + self.width, self.height = config.render_resolution + if config.render_focal is not None: + self.focal = config.render_focal + if config.render_camtype is not None: + if config.render_camtype == 'pano': + self._render_spherical = True + else: + self.camtype = camera_utils.ProjectionType(config.render_camtype) + + self.distortion_params = None + self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width, + self.height) + + self._n_examples = self.camtoworlds.shape[0] + + self.cameras = (self.pixtocams, + self.camtoworlds, + self.distortion_params, + self.pixtocam_ndc) + + # Seed the queue with one batch to avoid race condition. + if self.split == utils.DataSplit.TRAIN and not config.compute_visibility: + self._next_fn = self._next_train + else: + self._next_fn = self._next_test + + @property + def size(self): + return self._n_examples + + def __len__(self): + if self.split == utils.DataSplit.TRAIN and not self.config.compute_visibility: + return 1000 + else: + return self._n_examples + + @abc.abstractmethod + def _load_renderings(self, config): + """Load images and poses from disk. + + Args: + config: utils.Config, user-specified config parameters. + In inherited classes, this method must set the following public attributes: + images: [N, height, width, 3] array for RGB images. + disp_images: [N, height, width] array for depth data (optional). + normal_images: [N, height, width, 3] array for normals (optional). + camtoworlds: [N, 3, 4] array of extrinsic pose matrices. + poses: [..., 3, 4] array of auxiliary pose data (optional). 
+ pixtocams: [N, 3, 4] array of inverse intrinsic matrices. + distortion_params: dict, camera lens distortion model parameters. + height: int, height of images. + width: int, width of images. + focal: float, focal length to use for ideal pinhole rendering. + """ + + def _make_ray_batch(self, + pix_x_int, + pix_y_int, + cam_idx, + lossmult=None + ): + """Creates ray data batch from pixel coordinates and camera indices. + + All arguments must have broadcastable shapes. If the arguments together + broadcast to a shape [a, b, c, ..., z] then the returned utils.Rays object + will have array attributes with shape [a, b, c, ..., z, N], where N=3 for + 3D vectors and N=1 for per-ray scalar attributes. + + Args: + pix_x_int: int array, x coordinates of image pixels. + pix_y_int: int array, y coordinates of image pixels. + cam_idx: int or int array, camera indices. + lossmult: float array, weight to apply to each ray when computing loss fn. + + Returns: + A dict mapping from strings utils.Rays or arrays of image data. + This is the batch provided for one NeRF train or test iteration. + """ + + broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None] + ray_kwargs = { + 'lossmult': broadcast_scalar(1.) if lossmult is None else lossmult, + 'near': broadcast_scalar(self.near), + 'far': broadcast_scalar(self.far), + 'cam_idx': broadcast_scalar(cam_idx), + } + # Collect per-camera information needed for each ray. + if self.metadata is not None: + # Exposure index and relative shutter speed, needed for RawNeRF. + for key in ['exposure_idx', 'exposure_values']: + idx = 0 if self.render_path else cam_idx + ray_kwargs[key] = broadcast_scalar(self.metadata[key][idx]) + if self.exposures is not None: + idx = 0 if self.render_path else cam_idx + ray_kwargs['exposure_values'] = broadcast_scalar(self.exposures[idx]) + if self.render_path and self.render_exposures is not None: + ray_kwargs['exposure_values'] = broadcast_scalar( + self.render_exposures[cam_idx]) + + pixels = dict(pix_x_int=pix_x_int, pix_y_int=pix_y_int, **ray_kwargs) + + # Slow path, do ray computation using numpy (on CPU). + batch = camera_utils.cast_ray_batch(self.cameras, pixels, self.camtype) + batch['cam_dirs'] = -self.camtoworlds[ray_kwargs['cam_idx'][..., 0]][..., :3, 2] + + # import trimesh + # pts = batch['origins'][..., None, :] + batch['directions'][..., None, :] * np.linspace(0, 1, 5)[:, None] + # trimesh.Trimesh(vertices=pts.reshape(-1, 3)).export("test.ply", "ply") + # + # pts = batch['origins'][0, 0, None, :] - self.camtoworlds[cam_idx][:, 2] * np.linspace(0, 1, 100)[:, None] + # trimesh.Trimesh(vertices=pts.reshape(-1, 3)).export("test2.ply", "ply") + + if not self.render_path: + batch['rgb'] = self.images[cam_idx, pix_y_int, pix_x_int] + if self._load_disps: + batch['disps'] = self.disp_images[cam_idx, pix_y_int, pix_x_int] + if self._load_normals: + batch['normals'] = self.normal_images[cam_idx, pix_y_int, pix_x_int] + batch['alphas'] = self.alphas[cam_idx, pix_y_int, pix_x_int] + return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()} + + def _next_train(self, item): + """Sample next training batch (random rays).""" + # We assume all images in the dataset are the same resolution, so we can use + # the same width/height for sampling all pixels coordinates in the batch. + # Batch/patch sampling parameters. 
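+        # With patch size p and per-process batch size B, draw B // p**2 patches;
+        # each patch is a p x p block of pixels whose corner is sampled uniformly
+        # inside the (optionally masked) image border.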
+ num_patches = self._batch_size // self._patch_size ** 2 + lower_border = self._num_border_pixels_to_mask + upper_border = self._num_border_pixels_to_mask + self._patch_size - 1 + # Random pixel patch x-coordinates. + pix_x_int = np.random.randint(lower_border, self.width - upper_border, + (num_patches, 1, 1)) + # Random pixel patch y-coordinates. + pix_y_int = np.random.randint(lower_border, self.height - upper_border, + (num_patches, 1, 1)) + # Add patch coordinate offsets. + # Shape will broadcast to (num_patches, _patch_size, _patch_size). + patch_dx_int, patch_dy_int = camera_utils.pixel_coordinates( + self._patch_size, self._patch_size) + pix_x_int = pix_x_int + patch_dx_int + pix_y_int = pix_y_int + patch_dy_int + # Random camera indices. + if self._batching == utils.BatchingMethod.ALL_IMAGES: + cam_idx = np.random.randint(0, self._n_examples, (num_patches, 1, 1)) + else: + cam_idx = np.random.randint(0, self._n_examples, (1,)) + + if self._apply_bayer_mask: + # Compute the Bayer mosaic mask for each pixel in the batch. + lossmult = raw_utils.pixels_to_bayer_mask(pix_x_int, pix_y_int) + else: + lossmult = None + + return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx, + lossmult=lossmult) + + def generate_ray_batch(self, cam_idx: int): + """Generate ray batch for a specified camera in the dataset.""" + if self._render_spherical: + camtoworld = self.camtoworlds[cam_idx] + rays = camera_utils.cast_spherical_rays( + camtoworld, self.height, self.width, self.near, self.far) + return rays + else: + # Generate rays for all pixels in the image. + pix_x_int, pix_y_int = camera_utils.pixel_coordinates( + self.width, self.height) + return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx) + + def _next_test(self, item): + """Sample next test batch (one full image).""" + return self.generate_ray_batch(item) + + def collate_fn(self, item): + return self._next_fn(item[0]) + + def __getitem__(self, item): + return self._next_fn(item) + + +class Blender(Dataset): + """Blender Dataset.""" + + def _load_renderings(self, config): + """Load images from disk.""" + if config.render_path: + raise ValueError('render_path cannot be used for the blender dataset.') + pose_file = os.path.join(self.data_dir, f'transforms_{self.split.value}.json') + with utils.open_file(pose_file, 'r') as fp: + meta = json.load(fp) + images = [] + disp_images = [] + normal_images = [] + cams = [] + for idx, frame in enumerate(tqdm(meta['frames'], desc='Loading Blender dataset', disable=self.global_rank != 0, leave=False)): + fprefix = os.path.join(self.data_dir, frame['file_path']) + + def get_img(f, fprefix=fprefix): + image = utils.load_img(fprefix + f) + if config.factor > 1: + image = lib_image.downsample(image, config.factor) + return image + + if self._use_tiffs: + channels = [get_img(f'_{ch}.tiff') for ch in ['R', 'G', 'B', 'A']] + # Convert image to sRGB color space. + image = lib_image.linear_to_srgb_np(np.stack(channels, axis=-1)) + else: + image = get_img('.png') / 255. + images.append(image) + + if self._load_disps: + disp_image = get_img('_disp.tiff') + disp_images.append(disp_image) + if self._load_normals: + normal_image = get_img('_normal.png')[..., :3] * 2. / 255. - 1. 
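+                # The normal map PNG stores components in [0, 255]; the transform
+                # above remaps them to [-1, 1].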
+ normal_images.append(normal_image) + + cams.append(np.array(frame['transform_matrix'], dtype=np.float32)) + + self.images = np.stack(images, axis=0) + if self._load_disps: + self.disp_images = np.stack(disp_images, axis=0) + if self._load_normals: + self.normal_images = np.stack(normal_images, axis=0) + self.alphas = self.images[..., -1] + + rgb, alpha = self.images[..., :3], self.images[..., -1:] + self.images = rgb * alpha + (1. - alpha) # Use a white background. + self.height, self.width = self.images.shape[1:3] + self.camtoworlds = np.stack(cams, axis=0) + self.focal = .5 * self.width / np.tan(.5 * float(meta['camera_angle_x'])) + self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width, + self.height) + + +class LLFF(Dataset): + """LLFF Dataset.""" + + def _load_renderings(self, config): + """Load images from disk.""" + # Set up scaling factor. + image_dir_suffix = '' + # Use downsampling factor (unless loading training split for raw dataset, + # we train raw at full resolution because of the Bayer mosaic pattern). + if config.factor > 0 and not (config.rawnerf_mode and + self.split == utils.DataSplit.TRAIN): + image_dir_suffix = f'_{config.factor}' + factor = config.factor + else: + factor = 1 + + # Copy COLMAP data to local disk for faster loading. + colmap_dir = os.path.join(self.data_dir, 'sparse/0/') + + # Load poses. + if utils.file_exists(colmap_dir): + pose_data = NeRFSceneManager(colmap_dir).process() + else: + # # Attempt to load Blender/NGP format if COLMAP data not present. + # pose_data = load_blender_posedata(self.data_dir) + raise ValueError('COLMAP data not found.') + image_names, poses, pixtocam, distortion_params, camtype = pose_data + + # Previous NeRF results were generated with images sorted by filename, + # use this flag to ensure metrics are reported on the same test set. + inds = np.argsort(image_names) + image_names = [image_names[i] for i in inds] + poses = poses[inds] + + # Load bounds if possible (only used in forward facing scenes). + posefile = os.path.join(self.data_dir, 'poses_bounds.npy') + if utils.file_exists(posefile): + with utils.open_file(posefile, 'rb') as fp: + poses_arr = np.load(fp) + bounds = poses_arr[:, -2:] + else: + bounds = np.array([0.01, 1.]) + self.colmap_to_world_transform = np.eye(4) + + # Scale the inverse intrinsics matrix by the image downsampling factor. + pixtocam = pixtocam @ np.diag([factor, factor, 1.]) + self.pixtocams = pixtocam.astype(np.float32) + self.focal = 1. / self.pixtocams[0, 0] + self.distortion_params = distortion_params + self.camtype = camtype + + # Separate out 360 versus forward facing scenes. + if config.forward_facing: + # Set the projective matrix defining the NDC transformation. + self.pixtocam_ndc = self.pixtocams.reshape(-1, 3, 3)[0] + # Rescale according to a default bd factor. + scale = 1. / (bounds.min() * .75) + poses[:, :3, 3] *= scale + self.colmap_to_world_transform = np.diag([scale] * 3 + [1]) + bounds *= scale + # Recenter poses. + poses, transform = camera_utils.recenter_poses(poses) + self.colmap_to_world_transform = ( + transform @ self.colmap_to_world_transform) + # Forward-facing spiral render path. + self.render_poses = camera_utils.generate_spiral_path( + poses, bounds, n_frames=config.render_path_frames) + else: + # Rotate/scale poses to align ground with xy plane and fit to unit cube. 
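+            # transform_poses_pca() (see camera_utils) recenters the cameras and
+            # rotates them onto the principal axes of their positions; the returned
+            # transform is kept so COLMAP coordinates can later be mapped to world.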
+ poses, transform = camera_utils.transform_poses_pca(poses) + self.colmap_to_world_transform = transform + if config.render_spline_keyframes is not None: + rets = camera_utils.create_render_spline_path(config, image_names, + poses, self.exposures) + self.spline_indices, self.render_poses, self.render_exposures = rets + else: + # Automatically generated inward-facing elliptical render path. + self.render_poses = camera_utils.generate_ellipse_path( + poses, + n_frames=config.render_path_frames, + z_variation=config.z_variation, + z_phase=config.z_phase) + + # Select the split. + all_indices = np.arange(len(image_names)) + if config.llff_use_all_images_for_training: + train_indices = all_indices + else: + train_indices = all_indices % config.llffhold != 0 + if config.llff_use_all_images_for_testing: + test_indices = all_indices + else: + test_indices = all_indices % config.llffhold == 0 + split_indices = { + utils.DataSplit.TEST: all_indices[test_indices], + utils.DataSplit.TRAIN: all_indices[train_indices], + } + indices = split_indices[self.split] + image_names = [image_names[i] for i in indices] + poses = poses[indices] + # if self.split == utils.DataSplit.TRAIN: + # # load different training data on different rank + # local_indices = [i for i in range(len(image_names)) if (i + self.global_rank) % self.world_size == 0] + # image_names = [image_names[i] for i in local_indices] + # poses = poses[local_indices] + # indices = local_indices + + raw_testscene = False + if config.rawnerf_mode: + # Load raw images and metadata. + images, metadata, raw_testscene = raw_utils.load_raw_dataset( + self.split, + self.data_dir, + image_names, + config.exposure_percentile, + factor) + self.metadata = metadata + + else: + # Load images. + colmap_image_dir = os.path.join(self.data_dir, 'images') + image_dir = os.path.join(self.data_dir, 'images' + image_dir_suffix) + for d in [image_dir, colmap_image_dir]: + if not utils.file_exists(d): + raise ValueError(f'Image folder {d} does not exist.') + # Downsampled images may have different names vs images used for COLMAP, + # so we need to map between the two sorted lists of files. + colmap_files = sorted(utils.listdir(colmap_image_dir)) + image_files = sorted(utils.listdir(image_dir)) + colmap_to_image = dict(zip(colmap_files, image_files)) + image_paths = [os.path.join(image_dir, colmap_to_image[f]) + for f in image_names] + images = [utils.load_img(x) for x in tqdm(image_paths, desc='Loading LLFF dataset', disable=self.global_rank != 0, leave=False)] + images = np.stack(images, axis=0) / 255. + + # EXIF data is usually only present in the original JPEG images. + jpeg_paths = [os.path.join(colmap_image_dir, f) for f in image_names] + exifs = [utils.load_exif(x) for x in jpeg_paths] + self.exifs = exifs + if 'ExposureTime' in exifs[0] and 'ISOSpeedRatings' in exifs[0]: + gather_exif_value = lambda k: np.array([float(x[k]) for x in exifs]) + shutters = gather_exif_value('ExposureTime') + isos = gather_exif_value('ISOSpeedRatings') + self.exposures = shutters * isos / 1000. + + if raw_testscene: + # For raw testscene, the first image sent to COLMAP has the same pose as + # the ground truth test image. The remaining images form the training set. 
+ raw_testscene_poses = { + utils.DataSplit.TEST: poses[:1], + utils.DataSplit.TRAIN: poses[1:], + } + poses = raw_testscene_poses[self.split] + + self.poses = poses + self.images = images + self.camtoworlds = self.render_poses if config.render_path else poses + self.height, self.width = images.shape[1:3] + + +class TanksAndTemplesNerfPP(Dataset): + """Subset of Tanks and Temples Dataset as processed by NeRF++.""" + + def _load_renderings(self, config): + """Load images from disk.""" + if config.render_path: + split_str = 'camera_path' + else: + split_str = self.split.value + + basedir = os.path.join(self.data_dir, split_str) + + # TODO: need to rewrite this to put different data on different rank + def load_files(dirname, load_fn, shape=None): + files = [ + os.path.join(basedir, dirname, f) + for f in sorted(utils.listdir(os.path.join(basedir, dirname))) + ] + mats = np.array([load_fn(utils.open_file(f, 'rb')) for f in files]) + if shape is not None: + mats = mats.reshape(mats.shape[:1] + shape) + return mats + + poses = load_files('pose', np.loadtxt, (4, 4)) + # Flip Y and Z axes to get correct coordinate frame. + poses = np.matmul(poses, np.diag(np.array([1, -1, -1, 1]))) + + # For now, ignore all but the first focal length in intrinsics + intrinsics = load_files('intrinsics', np.loadtxt, (4, 4)) + + if not config.render_path: + images = load_files('rgb', lambda f: np.array(Image.open(f))) / 255. + self.images = images + self.height, self.width = self.images.shape[1:3] + + else: + # Hack to grab the image resolution from a test image + d = os.path.join(self.data_dir, 'test', 'rgb') + f = os.path.join(d, sorted(utils.listdir(d))[0]) + shape = utils.load_img(f).shape + self.height, self.width = shape[:2] + self.images = None + + self.camtoworlds = poses + self.focal = intrinsics[0, 0, 0] + self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width, + self.height) + + +class TanksAndTemplesFVS(Dataset): + """Subset of Tanks and Temples Dataset as processed by Free View Synthesis.""" + + def _load_renderings(self, config): + """Load images from disk.""" + render_only = config.render_path and self.split == utils.DataSplit.TEST + + basedir = os.path.join(self.data_dir, 'dense') + sizes = [f for f in sorted(utils.listdir(basedir)) if f.startswith('ibr3d')] + sizes = sizes[::-1] + + if config.factor >= len(sizes): + raise ValueError(f'Factor {config.factor} larger than {len(sizes)}') + + basedir = os.path.join(basedir, sizes[config.factor]) + open_fn = lambda f: utils.open_file(os.path.join(basedir, f), 'rb') + + files = [f for f in sorted(utils.listdir(basedir)) if f.startswith('im_')] + if render_only: + files = files[:1] + images = np.array([np.array(Image.open(open_fn(f))) for f in files]) / 255. + + names = ['Ks', 'Rs', 'ts'] + intrinsics, rot, trans = (np.load(open_fn(f'{n}.npy')) for n in names) + + # Convert poses from colmap world-to-cam into our cam-to-world. 
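+        # pad_poses() appends the homogeneous [0, 0, 0, 1] row so the stacked
+        # [R | t] matrices can be inverted; the diag() product then flips the y/z
+        # axes from COLMAP's convention into ours.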
+ w2c = np.concatenate([rot, trans[..., None]], axis=-1) + c2w_colmap = np.linalg.inv(camera_utils.pad_poses(w2c))[:, :3, :4] + c2w = c2w_colmap @ np.diag(np.array([1, -1, -1, 1])) + + # Reorient poses so z-axis is up + poses, _ = camera_utils.transform_poses_pca(c2w) + self.poses = poses + + self.images = images + self.height, self.width = self.images.shape[1:3] + self.camtoworlds = poses + # For now, ignore all but the first focal length in intrinsics + self.focal = intrinsics[0, 0, 0] + self.pixtocams = camera_utils.get_pixtocam(self.focal, self.width, + self.height) + + if render_only: + render_path = camera_utils.generate_ellipse_path( + poses, + config.render_path_frames, + z_variation=config.z_variation, + z_phase=config.z_phase) + self.images = None + self.camtoworlds = render_path + self.render_poses = render_path + else: + # Select the split. + all_indices = np.arange(images.shape[0]) + indices = { + utils.DataSplit.TEST: + all_indices[all_indices % config.llffhold == 0], + utils.DataSplit.TRAIN: + all_indices[all_indices % config.llffhold != 0], + }[self.split] + + self.images = self.images[indices] + self.camtoworlds = self.camtoworlds[indices] + + +class DTU(Dataset): + """DTU Dataset.""" + + def _load_renderings(self, config): + """Load images from disk.""" + if config.render_path: + raise ValueError('render_path cannot be used for the DTU dataset.') + + images = [] + pixtocams = [] + camtoworlds = [] + + # Find out whether the particular scan has 49 or 65 images. + n_images = len(utils.listdir(self.data_dir)) // 8 + + # Loop over all images. + for i in range(1, n_images + 1): + # Set light condition string accordingly. + if config.dtu_light_cond < 7: + light_str = f'{config.dtu_light_cond}_r' + ('5000' + if i < 50 else '7000') + else: + light_str = 'max' + + # Load image. + fname = os.path.join(self.data_dir, f'rect_{i:03d}_{light_str}.png') + image = utils.load_img(fname) / 255. + if config.factor > 1: + image = lib_image.downsample(image, config.factor) + images.append(image) + + # Load projection matrix from file. + fname = os.path.join(self.data_dir, f'../../cal18/pos_{i:03d}.txt') + with utils.open_file(fname, 'rb') as f: + projection = np.loadtxt(f, dtype=np.float32) + + # Decompose projection matrix into pose and camera matrix. + camera_mat, rot_mat, t = cv2.decomposeProjectionMatrix(projection)[:3] + camera_mat = camera_mat / camera_mat[2, 2] + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = rot_mat.transpose() + pose[:3, 3] = (t[:3] / t[3])[:, 0] + pose = pose[:3] + camtoworlds.append(pose) + + if config.factor > 0: + # Scale camera matrix according to downsampling factor. + camera_mat = np.diag([1. / config.factor, 1. / config.factor, 1. + ]).astype(np.float32) @ camera_mat + pixtocams.append(np.linalg.inv(camera_mat)) + + pixtocams = np.stack(pixtocams) + camtoworlds = np.stack(camtoworlds) + images = np.stack(images) + + def rescale_poses(poses): + """Rescales camera poses according to maximum x/y/z value.""" + s = np.max(np.abs(poses[:, :3, -1])) + out = np.copy(poses) + out[:, :3, -1] /= s + return out + + # Center and scale poses. + camtoworlds, _ = camera_utils.recenter_poses(camtoworlds) + camtoworlds = rescale_poses(camtoworlds) + # Flip y and z axes to get poses in OpenGL coordinate system. 
+ camtoworlds = camtoworlds @ np.diag([1., -1., -1., 1.]).astype(np.float32) + + all_indices = np.arange(images.shape[0]) + split_indices = { + utils.DataSplit.TEST: all_indices[all_indices % config.dtuhold == 0], + utils.DataSplit.TRAIN: all_indices[all_indices % config.dtuhold != 0], + } + indices = split_indices[self.split] + + self.images = images[indices] + self.height, self.width = images.shape[1:3] + self.camtoworlds = camtoworlds[indices] + self.pixtocams = pixtocams[indices] + + +class Multicam(Dataset): + def __init__(self, + split: str, + data_dir: str, + config: configs.Config): + super().__init__(split, data_dir, config) + + self.multiscale_levels = config.multiscale_levels + + images, camtoworlds, pixtocams, pixtocam_ndc = \ + self.images, self.camtoworlds, self.pixtocams, self.pixtocam_ndc + self.heights, self.widths, self.focals, self.images, self.camtoworlds, self.pixtocams, self.lossmults = [], [], [], [], [], [], [] + if pixtocam_ndc is not None: + self.pixtocam_ndc = [] + else: + self.pixtocam_ndc = None + + for i in range(self._n_examples): + for j in range(self.multiscale_levels): + self.heights.append(self.height // 2 ** j) + self.widths.append(self.width // 2 ** j) + + self.pixtocams.append(pixtocams @ np.diag([self.height / self.heights[-1], + self.width / self.widths[-1], + 1.])) + self.focals.append(1. / self.pixtocams[-1][0, 0]) + if config.forward_facing: + # Set the projective matrix defining the NDC transformation. + self.pixtocam_ndc.append(pixtocams.reshape(3, 3)) + + self.camtoworlds.append(camtoworlds[i]) + self.lossmults.append(2. ** j) + self.images.append(self.down2(images[i], (self.heights[-1], self.widths[-1]))) + self.pixtocams = np.stack(self.pixtocams) + self.camtoworlds = np.stack(self.camtoworlds) + self.cameras = (self.pixtocams, + self.camtoworlds, + self.distortion_params, + np.stack(self.pixtocam_ndc) if self.pixtocam_ndc is not None else None) + self._generate_rays() + + if self.split == utils.DataSplit.TRAIN: + # Always flatten out the height x width dimensions + def flatten(x): + if x[0] is not None: + x = [y.reshape([-1, y.shape[-1]]) for y in x] + if self._batching == utils.BatchingMethod.ALL_IMAGES: + # If global batching, also concatenate all data into one list + x = np.concatenate(x, axis=0) + return x + else: + return None + + self.batches = {k: flatten(v) for k, v in self.batches.items()} + self._n_examples = len(self.camtoworlds) + + # Seed the queue with one batch to avoid race condition. 
+ if self.split == utils.DataSplit.TRAIN: + self._next_fn = self._next_train + else: + self._next_fn = self._next_test + + def _generate_rays(self): + if self.global_rank == 0: + tbar = tqdm(range(len(self.camtoworlds)), desc='Generating rays', leave=False) + else: + tbar = range(len(self.camtoworlds)) + + self.batches = defaultdict(list) + for cam_idx in tbar: + pix_x_int, pix_y_int = camera_utils.pixel_coordinates( + self.widths[cam_idx], self.heights[cam_idx]) + broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None] + ray_kwargs = { + 'lossmult': broadcast_scalar(self.lossmults[cam_idx]), + 'near': broadcast_scalar(self.near), + 'far': broadcast_scalar(self.far), + 'cam_idx': broadcast_scalar(cam_idx), + } + + pixels = dict(pix_x_int=pix_x_int, pix_y_int=pix_y_int, **ray_kwargs) + + batch = camera_utils.cast_ray_batch(self.cameras, pixels, self.camtype) + if not self.render_path: + batch['rgb'] = self.images[cam_idx] + if self._load_disps: + batch['disps'] = self.disp_images[cam_idx, pix_y_int, pix_x_int] + if self._load_normals: + batch['normals'] = self.normal_images[cam_idx, pix_y_int, pix_x_int] + batch['alphas'] = self.alphas[cam_idx, pix_y_int, pix_x_int] + for k, v in batch.items(): + self.batches[k].append(v) + + def _next_train(self, item): + """Sample next training batch (random rays).""" + # We assume all images in the dataset are the same resolution, so we can use + # the same width/height for sampling all pixels coordinates in the batch. + # Batch/patch sampling parameters. + num_patches = self._batch_size // self._patch_size ** 2 + # Random camera indices. + if self._batching == utils.BatchingMethod.ALL_IMAGES: + ray_indices = np.random.randint(0, self.batches['origins'].shape[0], (num_patches, 1, 1)) + batch = {k: v[ray_indices] if v is not None else None for k, v in self.batches.items()} + else: + image_index = np.random.randint(0, self._n_examples, ()) + ray_indices = np.random.randint(0, self.batches['origins'][image_index].shape[0], (num_patches,)) + batch = {k: v[image_index][ray_indices] if v is not None else None for k, v in self.batches.items()} + batch['cam_dirs'] = -self.camtoworlds[batch['cam_idx'][..., 0]][..., 2] + return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()} + + def _next_test(self, item): + """Sample next test batch (one full image).""" + batch = {k: v[item] for k, v in self.batches.items()} + batch['cam_dirs'] = -self.camtoworlds[batch['cam_idx'][..., 0]][..., 2] + return {k: torch.from_numpy(v.copy()).float() if v is not None else None for k, v in batch.items()} + + @staticmethod + def down2(img, sh): + return cv2.resize(img, sh[::-1], interpolation=cv2.INTER_CUBIC) + + +class MultiLLFF(Multicam, LLFF): + pass + + +if __name__ == '__main__': + from internal import configs + import accelerate + + config = configs.Config() + accelerator = accelerate.Accelerator() + config.world_size = accelerator.num_processes + config.global_rank = accelerator.process_index + config.factor = 8 + dataset = LLFF('test', '/SSD_DISK/datasets/360_v2/bicycle', config) + print(len(dataset)) + for _ in tqdm(dataset): + pass + print('done') + # print(accelerator.process_index) diff --git a/internal/geopoly.py b/internal/geopoly.py new file mode 100644 index 0000000000000000000000000000000000000000..69d921b2ac13789d08b7ae551fc6105252f112ca --- /dev/null +++ b/internal/geopoly.py @@ -0,0 +1,108 @@ +import itertools +import numpy as np + + +def compute_sq_dist(mat0, mat1=None): + """Compute the squared 
Euclidean distance between all pairs of columns.""" + if mat1 is None: + mat1 = mat0 + # Use the fact that ||x - y||^2 == ||x||^2 + ||y||^2 - 2 x^T y. + sq_norm0 = np.sum(mat0 ** 2, 0) + sq_norm1 = np.sum(mat1 ** 2, 0) + sq_dist = sq_norm0[:, None] + sq_norm1[None, :] - 2 * mat0.T @ mat1 + sq_dist = np.maximum(0, sq_dist) # Negative values must be numerical errors. + return sq_dist + + +def compute_tesselation_weights(v): + """Tesselate the vertices of a triangle by a factor of `v`.""" + if v < 1: + raise ValueError(f'v {v} must be >= 1') + int_weights = [] + for i in range(v + 1): + for j in range(v + 1 - i): + int_weights.append((i, j, v - (i + j))) + int_weights = np.array(int_weights) + weights = int_weights / v # Barycentric weights. + return weights + + +def tesselate_geodesic(base_verts, base_faces, v, eps=1e-4): + """Tesselate the vertices of a geodesic polyhedron. + + Args: + base_verts: tensor of floats, the vertex coordinates of the geodesic. + base_faces: tensor of ints, the indices of the vertices of base_verts that + constitute eachface of the polyhedra. + v: int, the factor of the tesselation (v==1 is a no-op). + eps: float, a small value used to determine if two vertices are the same. + + Returns: + verts: a tensor of floats, the coordinates of the tesselated vertices. + """ + if not isinstance(v, int): + raise ValueError(f'v {v} must an integer') + tri_weights = compute_tesselation_weights(v) + + verts = [] + for base_face in base_faces: + new_verts = np.matmul(tri_weights, base_verts[base_face, :]) + new_verts /= np.sqrt(np.sum(new_verts ** 2, 1, keepdims=True)) + verts.append(new_verts) + verts = np.concatenate(verts, 0) + + sq_dist = compute_sq_dist(verts.T) + assignment = np.array([np.min(np.argwhere(d <= eps)) for d in sq_dist]) + unique = np.unique(assignment) + verts = verts[unique, :] + + return verts + + +def generate_basis(base_shape, + angular_tesselation, + remove_symmetries=True, + eps=1e-4): + """Generates a 3D basis by tesselating a geometric polyhedron. + + Args: + base_shape: string, the name of the starting polyhedron, must be either + 'icosahedron' or 'octahedron'. + angular_tesselation: int, the number of times to tesselate the polyhedron, + must be >= 1 (a value of 1 is a no-op to the polyhedron). + remove_symmetries: bool, if True then remove the symmetric basis columns, + which is usually a good idea because otherwise projections onto the basis + will have redundant negative copies of each other. + eps: float, a small number used to determine symmetries. + + Returns: + basis: a matrix with shape [3, n]. 
+ """ + if base_shape == 'icosahedron': + a = (np.sqrt(5) + 1) / 2 + verts = np.array([(-1, 0, a), (1, 0, a), (-1, 0, -a), (1, 0, -a), (0, a, 1), + (0, a, -1), (0, -a, 1), (0, -a, -1), (a, 1, 0), + (-a, 1, 0), (a, -1, 0), (-a, -1, 0)]) / np.sqrt(a + 2) + faces = np.array([(0, 4, 1), (0, 9, 4), (9, 5, 4), (4, 5, 8), (4, 8, 1), + (8, 10, 1), (8, 3, 10), (5, 3, 8), (5, 2, 3), (2, 7, 3), + (7, 10, 3), (7, 6, 10), (7, 11, 6), (11, 0, 6), (0, 1, 6), + (6, 1, 10), (9, 0, 11), (9, 11, 2), (9, 2, 5), + (7, 2, 11)]) + verts = tesselate_geodesic(verts, faces, angular_tesselation) + elif base_shape == 'octahedron': + verts = np.array([(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), + (1, 0, 0)]) + corners = np.array(list(itertools.product([-1, 1], repeat=3))) + pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2) + faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1) + verts = tesselate_geodesic(verts, faces, angular_tesselation) + else: + raise ValueError(f'base_shape {base_shape} not supported') + + if remove_symmetries: + # Remove elements of `verts` that are reflections of each other. + match = compute_sq_dist(verts.T, -verts.T) < eps + verts = verts[np.any(np.triu(match), 1), :] + + basis = verts[:, ::-1] + return basis diff --git a/internal/image.py b/internal/image.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7614a3f95f08331d88678390e51c5fb485f1f4 --- /dev/null +++ b/internal/image.py @@ -0,0 +1,126 @@ +import torch +import numpy as np +from internal import math +from skimage.metrics import structural_similarity, peak_signal_noise_ratio +import cv2 + + +def mse_to_psnr(mse): + """Compute PSNR given an MSE (we assume the maximum pixel value is 1).""" + return -10. / np.log(10.) * np.log(mse) + + +def psnr_to_mse(psnr): + """Compute MSE given a PSNR (we assume the maximum pixel value is 1).""" + return np.exp(-0.1 * np.log(10.) 
* psnr) + + +def ssim_to_dssim(ssim): + """Compute DSSIM given an SSIM.""" + return (1 - ssim) / 2 + + +def dssim_to_ssim(dssim): + """Compute DSSIM given an SSIM.""" + return 1 - 2 * dssim + + +def linear_to_srgb(linear, eps=None): + """Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.""" + if eps is None: + eps = torch.finfo(linear.dtype).eps + # eps = 1e-3 + + srgb0 = 323 / 25 * linear + srgb1 = (211 * linear.clamp_min(eps) ** (5 / 12) - 11) / 200 + return torch.where(linear <= 0.0031308, srgb0, srgb1) + + +def linear_to_srgb_np(linear, eps=None): + """Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.""" + if eps is None: + eps = np.finfo(linear.dtype).eps + srgb0 = 323 / 25 * linear + srgb1 = (211 * np.maximum(eps, linear) ** (5 / 12) - 11) / 200 + return np.where(linear <= 0.0031308, srgb0, srgb1) + + +def srgb_to_linear(srgb, eps=None): + """Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.""" + if eps is None: + eps = np.finfo(srgb.dtype).eps + linear0 = 25 / 323 * srgb + linear1 = np.maximum(eps, ((200 * srgb + 11) / (211))) ** (12 / 5) + return np.where(srgb <= 0.04045, linear0, linear1) + + +def downsample(img, factor): + """Area downsample img (factor must evenly divide img height and width).""" + sh = img.shape + if not (sh[0] % factor == 0 and sh[1] % factor == 0): + raise ValueError(f'Downsampling factor {factor} does not ' + f'evenly divide image shape {sh[:2]}') + img = img.reshape((sh[0] // factor, factor, sh[1] // factor, factor) + sh[2:]) + img = img.mean((1, 3)) + return img + + +def color_correct(img, ref, num_iters=5, eps=0.5 / 255): + """Warp `img` to match the colors in `ref_img`.""" + if img.shape[-1] != ref.shape[-1]: + raise ValueError( + f'img\'s {img.shape[-1]} and ref\'s {ref.shape[-1]} channels must match' + ) + num_channels = img.shape[-1] + img_mat = img.reshape([-1, num_channels]) + ref_mat = ref.reshape([-1, num_channels]) + is_unclipped = lambda z: (z >= eps) & (z <= (1 - eps)) # z \in [eps, 1-eps]. + mask0 = is_unclipped(img_mat) + # Because the set of saturated pixels may change after solving for a + # transformation, we repeatedly solve a system `num_iters` times and update + # our estimate of which pixels are saturated. + for _ in range(num_iters): + # Construct the left hand side of a linear system that contains a quadratic + # expansion of each pixel of `img`. + a_mat = [] + for c in range(num_channels): + a_mat.append(img_mat[:, c:(c + 1)] * img_mat[:, c:]) # Quadratic term. + a_mat.append(img_mat) # Linear term. + a_mat.append(torch.ones_like(img_mat[:, :1])) # Bias term. + a_mat = torch.cat(a_mat, dim=-1) + warp = [] + for c in range(num_channels): + # Construct the right hand side of a linear system containing each color + # of `ref`. + b = ref_mat[:, c] + # Ignore rows of the linear system that were saturated in the input or are + # saturated in the current corrected color estimate. + mask = mask0[:, c] & is_unclipped(img_mat[:, c]) & is_unclipped(b) + ma_mat = torch.where(mask[:, None], a_mat, torch.zeros_like(a_mat)) + mb = torch.where(mask, b, torch.zeros_like(b)) + w = torch.linalg.lstsq(ma_mat, mb, rcond=-1)[0] + assert torch.all(torch.isfinite(w)) + warp.append(w) + warp = torch.stack(warp, dim=-1) + # Apply the warp to update img_mat. 
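+    # Each corrected channel is a quadratic polynomial of the original channels (quadratic + linear + bias terms in a_mat);
+    # clipping to [0, 1] keeps the estimate valid before the saturation mask is re-evaluated on the next iteration.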
+ img_mat = torch.clip(math.matmul(a_mat, warp), 0, 1) + corrected_img = torch.reshape(img_mat, img.shape) + return corrected_img + + +class MetricHarness: + """A helper class for evaluating several error metrics.""" + + def __call__(self, rgb_pred, rgb_gt, name_fn=lambda s: s): + """Evaluate the error between a predicted rgb image and the true image.""" + rgb_pred = (rgb_pred * 255).astype(np.uint8) + rgb_gt = (rgb_gt * 255).astype(np.uint8) + rgb_pred_gray = cv2.cvtColor(rgb_pred, cv2.COLOR_RGB2GRAY) + rgb_gt_gray = cv2.cvtColor(rgb_gt, cv2.COLOR_RGB2GRAY) + psnr = float(peak_signal_noise_ratio(rgb_pred, rgb_gt, data_range=255)) + ssim = float(structural_similarity(rgb_pred_gray, rgb_gt_gray, data_range=255)) + + return { + name_fn('psnr'): psnr, + name_fn('ssim'): ssim, + } diff --git a/internal/math.py b/internal/math.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c6dc6f4eb2d3c0541dab6fe1800e236311c2cf --- /dev/null +++ b/internal/math.py @@ -0,0 +1,133 @@ +import numpy as np +import torch + + +@torch.jit.script +def erf(x): + return torch.sign(x) * torch.sqrt(1 - torch.exp(-4 / torch.pi * x ** 2)) + + +def matmul(a, b): + return (a[..., None] * b[..., None, :, :]).sum(dim=-2) + # B,3,4,1 B,1,4,3 + + # cause nan when fp16 + # return torch.matmul(a, b) + + +def safe_trig_helper(x, fn, t=100 * torch.pi): + """Helper function used by safe_cos/safe_sin: mods x before sin()/cos().""" + return fn(torch.where(torch.abs(x) < t, x, x % t)) + + +def safe_cos(x): + return safe_trig_helper(x, torch.cos) + + +def safe_sin(x): + return safe_trig_helper(x, torch.sin) + + +def safe_exp(x): + return torch.exp(x.clamp_max(88.)) + + +def safe_exp_jvp(primals, tangents): + """Override safe_exp()'s gradient so that it's large when inputs are large.""" + x, = primals + x_dot, = tangents + exp_x = safe_exp(x) + exp_x_dot = exp_x * x_dot + return exp_x, exp_x_dot + + +def log_lerp(t, v0, v1): + """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1).""" + if v0 <= 0 or v1 <= 0: + raise ValueError(f'Interpolants {v0} and {v1} must be positive.') + lv0 = np.log(v0) + lv1 = np.log(v1) + return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0) + + +def learning_rate_decay(step, + lr_init, + lr_final, + max_steps, + lr_delay_steps=0, + lr_delay_mult=1): + """Continuous learning rate decay function. + + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). + If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + + Args: + step: int, the current optimization step. + lr_init: float, the initial learning rate. + lr_final: float, the final learning rate. + max_steps: int, the number of steps during optimization. + lr_delay_steps: int, the number of steps to delay the full learning rate. + lr_delay_mult: float, the multiplier on the rate when delaying it. + + Returns: + lr: the learning for current step 'step'. + """ + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)) + else: + delay_rate = 1. 
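+    # The base schedule is log-linear: lr(step) = exp((1 - t) * log(lr_init) + t * log(lr_final)) with
+    # t = step / max_steps clipped to [0, 1], then scaled by the warmup factor computed above.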
+ return delay_rate * log_lerp(step / max_steps, lr_init, lr_final) + + +def sorted_interp(x, xp, fp): + """A TPU-friendly version of interp(), where xp and fp must be sorted.""" + + # Identify the location in `xp` that corresponds to each `x`. + # The final `True` index in `mask` is the start of the matching interval. + mask = x[..., None, :] >= xp[..., :, None] + + def find_interval(x): + # Grab the value where `mask` switches from True to False, and vice versa. + # This approach takes advantage of the fact that `x` is sorted. + x0 = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2).values + x1 = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2).values + return x0, x1 + + fp0, fp1 = find_interval(fp) + xp0, xp1 = find_interval(xp) + + offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1) + ret = fp0 + offset * (fp1 - fp0) + return ret + + +def sorted_interp_quad(x, xp, fpdf, fcdf): + """interp in quadratic""" + + # Identify the location in `xp` that corresponds to each `x`. + # The final `True` index in `mask` is the start of the matching interval. + mask = x[..., None, :] >= xp[..., :, None] + + def find_interval(x, return_idx=False): + # Grab the value where `mask` switches from True to False, and vice versa. + # This approach takes advantage of the fact that `x` is sorted. + x0, x0_idx = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2) + x1, x1_idx = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2) + if return_idx: + return x0, x1, x0_idx, x1_idx + return x0, x1 + + fcdf0, fcdf1, fcdf0_idx, fcdf1_idx = find_interval(fcdf, return_idx=True) + fpdf0 = fpdf.take_along_dim(fcdf0_idx, dim=-1) + fpdf1 = fpdf.take_along_dim(fcdf1_idx, dim=-1) + xp0, xp1 = find_interval(xp) + + offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1) + ret = fcdf0 + (x - xp0) * (fpdf0 + fpdf1 * offset + fpdf0 * (1 - offset)) / 2 + return ret diff --git a/internal/models.py b/internal/models.py new file mode 100644 index 0000000000000000000000000000000000000000..043c91f60ff0f67016f286fd44602adb633cda30 --- /dev/null +++ b/internal/models.py @@ -0,0 +1,740 @@ +import accelerate +import gin +from internal import coord +from internal import geopoly +from internal import image +from internal import math +from internal import ref_utils +from internal import train_utils +from internal import render +from internal import stepfun +from internal import utils +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils._pytree import tree_map +from tqdm import tqdm +from gridencoder import GridEncoder +from torch_scatter import segment_coo + +gin.config.external_configurable(math.safe_exp, module='math') + + +def set_kwargs(self, kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +@gin.configurable +class Model(nn.Module): + """A mip-Nerf360 model containing all MLPs.""" + num_prop_samples: int = 64 # The number of samples for each proposal level. + num_nerf_samples: int = 32 # The number of samples the final nerf level. + num_levels: int = 3 # The number of sampling levels (3==2 proposals, 1 nerf). + bg_intensity_range = (1., 1.) # The range of background colors. + anneal_slope: float = 10 # Higher = more rapid annealing. + stop_level_grad: bool = True # If True, don't backprop across levels. + use_viewdirs: bool = True # If True, use view directions as input. + raydist_fn = None # The curve used for ray dists. 
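+    # These fields are gin-configurable class-level defaults; __init__ overrides them from **kwargs via set_kwargs.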
+ single_jitter: bool = True # If True, jitter whole rays instead of samples. + dilation_multiplier: float = 0.5 # How much to dilate intervals relatively. + dilation_bias: float = 0.0025 # How much to dilate intervals absolutely. + num_glo_features: int = 0 # GLO vector length, disabled if 0. + num_glo_embeddings: int = 1000 # Upper bound on max number of train images. + learned_exposure_scaling: bool = False # Learned exposure scaling (RawNeRF). + near_anneal_rate = None # How fast to anneal in near bound. + near_anneal_init: float = 0.95 # Where to initialize near bound (in [0, 1]). + single_mlp: bool = False # Use the NerfMLP for all rounds of sampling. + distinct_prop: bool = True # Use the NerfMLP for all rounds of sampling. + resample_padding: float = 0.0 # Dirichlet/alpha "padding" on the histogram. + opaque_background: bool = False # If true, make the background opaque. + power_lambda: float = -1.5 + std_scale: float = 0.5 + prop_desired_grid_size = [512, 2048] + + def __init__(self, config=None, **kwargs): + super().__init__() + set_kwargs(self, kwargs) + self.config = config + + # Construct MLPs. WARNING: Construction order may matter, if MLP weights are + # being regularized. + self.nerf_mlp = NerfMLP(num_glo_features=self.num_glo_features, + num_glo_embeddings=self.num_glo_embeddings) + if self.single_mlp: + self.prop_mlp = self.nerf_mlp + elif not self.distinct_prop: + self.prop_mlp = PropMLP() + else: + for i in range(self.num_levels - 1): + self.register_module(f'prop_mlp_{i}', PropMLP(grid_disired_resolution=self.prop_desired_grid_size[i])) + if self.num_glo_features > 0 and not config.zero_glo: + # Construct/grab GLO vectors for the cameras of each input ray. + self.glo_vecs = nn.Embedding(self.num_glo_embeddings, self.num_glo_features) + + if self.learned_exposure_scaling: + # Setup learned scaling factors for output colors. + max_num_exposures = self.num_glo_embeddings + # Initialize the learned scaling offsets at 0. + self.exposure_scaling_offsets = nn.Embedding(max_num_exposures, 3) + torch.nn.init.zeros_(self.exposure_scaling_offsets.weight) + + def forward( + self, + rand, + batch, + train_frac, + compute_extras, + zero_glo=True, + ): + """The mip-NeRF Model. + + Args: + rand: random number generator (or None for deterministic output). + batch: util.Rays, a pytree of ray origins, directions, and viewdirs. + train_frac: float in [0, 1], what fraction of training is complete. + compute_extras: bool, if True, compute extra quantities besides color. + zero_glo: bool, if True, when using GLO pass in vector of zeros. + + Returns: + ret: list, [*(rgb, distance, acc)] + """ + device = batch['origins'].device + if self.num_glo_features > 0: + if not zero_glo: + # Construct/grab GLO vectors for the cameras of each input ray. + cam_idx = batch['cam_idx'][..., 0] + glo_vec = self.glo_vecs(cam_idx.long()) + else: + glo_vec = torch.zeros(batch['origins'].shape[:-1] + (self.num_glo_features,), device=device) + else: + glo_vec = None + + # Define the mapping from normalized to metric ray distance. + _, s_to_t = coord.construct_ray_warps(self.raydist_fn, batch['near'], batch['far'], self.power_lambda) + + # Initialize the range of (normalized) distances for each ray to [0, 1], + # and assign that single interval a weight of 1. These distances and weights + # will be repeatedly updated as we proceed through sampling levels. + # `near_anneal_rate` can be used to anneal in the near bound at the start + # of training, eg. 0.1 anneals in the bound over the first 10% of training. 
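+        # Concretely: init_s_near = clip(1 - train_frac / near_anneal_rate, 0, near_anneal_init), so the near bound
+        # eases in linearly over the first `near_anneal_rate` fraction of training.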
+ if self.near_anneal_rate is None: + init_s_near = 0. + else: + init_s_near = np.clip(1 - train_frac / self.near_anneal_rate, 0, + self.near_anneal_init) + init_s_far = 1. + sdist = torch.cat([ + torch.full_like(batch['near'], init_s_near), + torch.full_like(batch['far'], init_s_far) + ], dim=-1) + weights = torch.ones_like(batch['near']) + prod_num_samples = 1 + + ray_history = [] + renderings = [] + for i_level in range(self.num_levels): + is_prop = i_level < (self.num_levels - 1) + num_samples = self.num_prop_samples if is_prop else self.num_nerf_samples + + # Dilate by some multiple of the expected span of each current interval, + # with some bias added in. + dilation = self.dilation_bias + self.dilation_multiplier * ( + init_s_far - init_s_near) / prod_num_samples + + # Record the product of the number of samples seen so far. + prod_num_samples *= num_samples + + # After the first level (where dilation would be a no-op) optionally + # dilate the interval weights along each ray slightly so that they're + # overestimates, which can reduce aliasing. + use_dilation = self.dilation_bias > 0 or self.dilation_multiplier > 0 + if i_level > 0 and use_dilation: + sdist, weights = stepfun.max_dilate_weights( + sdist, + weights, + dilation, + domain=(init_s_near, init_s_far), + renormalize=True) + sdist = sdist[..., 1:-1] + weights = weights[..., 1:-1] + + # Optionally anneal the weights as a function of training iteration. + if self.anneal_slope > 0: + # Schlick's bias function, see https://arxiv.org/abs/2010.09714 + bias = lambda x, s: (s * x) / ((s - 1) * x + 1) + anneal = bias(train_frac, self.anneal_slope) + else: + anneal = 1. + + # A slightly more stable way to compute weights**anneal. If the distance + # between adjacent intervals is zero then its weight is fixed to 0. + logits_resample = torch.where( + sdist[..., 1:] > sdist[..., :-1], + anneal * torch.log(weights + self.resample_padding), + torch.full_like(sdist[..., :-1], -torch.inf)) + + # Draw sampled intervals from each ray's current weights. + sdist = stepfun.sample_intervals( + rand, + sdist, + logits_resample, + num_samples, + single_jitter=self.single_jitter, + domain=(init_s_near, init_s_far)) + + # Optimization will usually go nonlinear if you propagate gradients + # through sampling. + if self.stop_level_grad: + sdist = sdist.detach() + + # Convert normalized distances to metric distances. + tdist = s_to_t(sdist) + + # Cast our rays, by turning our distance intervals into Gaussians. + means, stds, ts = render.cast_rays( + tdist, + batch['origins'], + batch['directions'], + batch['cam_dirs'], + batch['radii'], + rand, + std_scale=self.std_scale) + + # Push our Gaussians through one of our two MLPs. + mlp = (self.get_submodule( + f'prop_mlp_{i_level}') if self.distinct_prop else self.prop_mlp) if is_prop else self.nerf_mlp + ray_results = mlp( + rand, + means, stds, + viewdirs=batch['viewdirs'] if self.use_viewdirs else None, + imageplane=batch.get('imageplane'), + glo_vec=None if is_prop else glo_vec, + exposure=batch.get('exposure_values'), + ) + if self.config.gradient_scaling: + ray_results['rgb'], ray_results['density'] = train_utils.GradientScaler.apply( + ray_results['rgb'], ray_results['density'], ts.mean(dim=-1)) + + # Get the weights used by volumetric rendering (and our other losses). + weights = render.compute_alpha_weights( + ray_results['density'], + tdist, + batch['directions'], + opaque_background=self.opaque_background, + )[0] + + # Define or sample the background color for each ray. 
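+            # With the default bg_intensity_range = (1., 1.) this reduces to a constant white background; a wider range
+            # gives each ray a uniformly sampled background color at train time and the range midpoint at eval time.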
+ if self.bg_intensity_range[0] == self.bg_intensity_range[1]: + # If the min and max of the range are equal, just take it. + bg_rgbs = self.bg_intensity_range[0] + elif rand is None: + # If rendering is deterministic, use the midpoint of the range. + bg_rgbs = (self.bg_intensity_range[0] + self.bg_intensity_range[1]) / 2 + else: + # Sample RGB values from the range for each ray. + minval = self.bg_intensity_range[0] + maxval = self.bg_intensity_range[1] + bg_rgbs = torch.rand(weights.shape[:-1] + (3,), device=device) * (maxval - minval) + minval + + # RawNeRF exposure logic. + if batch.get('exposure_idx') is not None: + # Scale output colors by the exposure. + ray_results['rgb'] *= batch['exposure_values'][..., None, :] + if self.learned_exposure_scaling: + exposure_idx = batch['exposure_idx'][..., 0] + # Force scaling offset to always be zero when exposure_idx is 0. + # This constraint fixes a reference point for the scene's brightness. + mask = exposure_idx > 0 + # Scaling is parameterized as an offset from 1. + scaling = 1 + mask[..., None] * self.exposure_scaling_offsets(exposure_idx.long()) + ray_results['rgb'] *= scaling[..., None, :] + + # Render each ray. + rendering = render.volumetric_rendering( + ray_results['rgb'], + weights, + tdist, + bg_rgbs, + batch['far'], + compute_extras, + extras={ + k: v + for k, v in ray_results.items() + if k.startswith('normals') or k in ['roughness'] + }) + + if compute_extras: + # Collect some rays to visualize directly. By naming these quantities + # with `ray_` they get treated differently downstream --- they're + # treated as bags of rays, rather than image chunks. + n = self.config.vis_num_rays + rendering['ray_sdist'] = sdist.reshape([-1, sdist.shape[-1]])[:n, :] + rendering['ray_weights'] = ( + weights.reshape([-1, weights.shape[-1]])[:n, :]) + rgb = ray_results['rgb'] + rendering['ray_rgbs'] = (rgb.reshape((-1,) + rgb.shape[-2:]))[:n, :, :] + + if self.training: + # Compute the hash decay loss for this level. + idx = mlp.encoder.idx + param = mlp.encoder.embeddings + loss_hash_decay = segment_coo(param ** 2, + idx, + torch.zeros(idx.max() + 1, param.shape[-1], device=param.device), + reduce='mean' + ).mean() + ray_results['loss_hash_decay'] = loss_hash_decay + + renderings.append(rendering) + ray_results['sdist'] = sdist.clone() + ray_results['weights'] = weights.clone() + ray_history.append(ray_results) + + if compute_extras: + # Because the proposal network doesn't produce meaningful colors, for + # easier visualization we replace their colors with the final average + # color. + weights = [r['ray_weights'] for r in renderings] + rgbs = [r['ray_rgbs'] for r in renderings] + final_rgb = torch.sum(rgbs[-1] * weights[-1][..., None], dim=-2) + avg_rgbs = [ + torch.broadcast_to(final_rgb[:, None, :], r.shape) for r in rgbs[:-1] + ] + for i in range(len(avg_rgbs)): + renderings[i]['ray_rgbs'] = avg_rgbs[i] + + return renderings, ray_history + + +class MLP(nn.Module): + """A PosEnc MLP.""" + bottleneck_width: int = 256 # The width of the bottleneck vector. + net_depth_viewdirs: int = 2 # The depth of the second part of ML. + net_width_viewdirs: int = 256 # The width of the second part of MLP. + skip_layer_dir: int = 0 # Add a skip connection to 2nd MLP after Nth layers. + num_rgb_channels: int = 3 # The number of RGB channels. + deg_view: int = 4 # Degree of encoding for viewdirs or refdirs. + use_reflections: bool = False # If True, use refdirs instead of viewdirs. + use_directional_enc: bool = False # If True, use IDE to encode directions. 
+ # If False and if use_directional_enc is True, use zero roughness in IDE. + enable_pred_roughness: bool = False + roughness_bias: float = -1. # Shift added to raw roughness pre-activation. + use_diffuse_color: bool = False # If True, predict diffuse & specular colors. + use_specular_tint: bool = False # If True, predict tint. + use_n_dot_v: bool = False # If True, feed dot(n * viewdir) to 2nd MLP. + bottleneck_noise: float = 0.0 # Std. deviation of noise added to bottleneck. + density_bias: float = -1. # Shift added to raw densities pre-activation. + density_noise: float = 0. # Standard deviation of noise added to raw density. + rgb_premultiplier: float = 1. # Premultiplier on RGB before activation. + rgb_bias: float = 0. # The shift added to raw colors pre-activation. + rgb_padding: float = 0.001 # Padding added to the RGB outputs. + enable_pred_normals: bool = False # If True compute predicted normals. + disable_density_normals: bool = False # If True don't compute normals. + disable_rgb: bool = False # If True don't output RGB. + warp_fn = 'contract' + num_glo_features: int = 0 # GLO vector length, disabled if 0. + num_glo_embeddings: int = 1000 # Upper bound on max number of train images. + scale_featurization: bool = False + grid_num_levels: int = 10 + grid_level_interval: int = 2 + grid_level_dim: int = 4 + grid_base_resolution: int = 16 + grid_disired_resolution: int = 8192 + grid_log2_hashmap_size: int = 21 + net_width_glo: int = 128 # The width of the second part of MLP. + net_depth_glo: int = 2 # The width of the second part of MLP. + + def __init__(self, **kwargs): + super().__init__() + set_kwargs(self, kwargs) + # Make sure that normals are computed if reflection direction is used. + if self.use_reflections and not (self.enable_pred_normals or + not self.disable_density_normals): + raise ValueError('Normals must be computed for reflection directions.') + + # Precompute and define viewdir or refdir encoding function. + if self.use_directional_enc: + self.dir_enc_fn = ref_utils.generate_ide_fn(self.deg_view) + dim_dir_enc = self.dir_enc_fn(torch.zeros(1, 3), torch.zeros(1, 1)).shape[-1] + else: + + def dir_enc_fn(direction, _): + return coord.pos_enc( + direction, min_deg=0, max_deg=self.deg_view, append_identity=True) + + self.dir_enc_fn = dir_enc_fn + dim_dir_enc = self.dir_enc_fn(torch.zeros(1, 3), None).shape[-1] + self.grid_num_levels = int( + np.log(self.grid_disired_resolution / self.grid_base_resolution) / np.log(self.grid_level_interval)) + 1 + self.encoder = GridEncoder(input_dim=3, + num_levels=self.grid_num_levels, + level_dim=self.grid_level_dim, + base_resolution=self.grid_base_resolution, + desired_resolution=self.grid_disired_resolution, + log2_hashmap_size=self.grid_log2_hashmap_size, + gridtype='hash', + align_corners=False) + last_dim = self.encoder.output_dim + if self.scale_featurization: + last_dim += self.encoder.num_levels + self.density_layer = nn.Sequential(nn.Linear(last_dim, 64), + nn.ReLU(), + nn.Linear(64, + 1 if self.disable_rgb else self.bottleneck_width)) # Hardcoded to a single channel. 
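+        # The first channel of this layer's output is used as the raw density; when RGB is enabled, the full output
+        # vector also serves as the bottleneck feature fed to the view-dependent branch (see forward()).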
+ last_dim = 1 if self.disable_rgb and not self.enable_pred_normals else self.bottleneck_width + if self.enable_pred_normals: + self.normal_layer = nn.Linear(last_dim, 3) + + if not self.disable_rgb: + if self.use_diffuse_color: + self.diffuse_layer = nn.Linear(last_dim, self.num_rgb_channels) + + if self.use_specular_tint: + self.specular_layer = nn.Linear(last_dim, 3) + + if self.enable_pred_roughness: + self.roughness_layer = nn.Linear(last_dim, 1) + + # Output of the first part of MLP. + if self.bottleneck_width > 0: + last_dim_rgb = self.bottleneck_width + else: + last_dim_rgb = 0 + + last_dim_rgb += dim_dir_enc + + if self.use_n_dot_v: + last_dim_rgb += 1 + + if self.num_glo_features > 0: + last_dim_glo = self.num_glo_features + for i in range(self.net_depth_glo - 1): + self.register_module(f"lin_glo_{i}", nn.Linear(last_dim_glo, self.net_width_glo)) + last_dim_glo = self.net_width_glo + self.register_module(f"lin_glo_{self.net_depth_glo - 1}", + nn.Linear(last_dim_glo, self.bottleneck_width * 2)) + + input_dim_rgb = last_dim_rgb + for i in range(self.net_depth_viewdirs): + lin = nn.Linear(last_dim_rgb, self.net_width_viewdirs) + torch.nn.init.kaiming_uniform_(lin.weight) + self.register_module(f"lin_second_stage_{i}", lin) + last_dim_rgb = self.net_width_viewdirs + if i == self.skip_layer_dir: + last_dim_rgb += input_dim_rgb + self.rgb_layer = nn.Linear(last_dim_rgb, self.num_rgb_channels) + + def predict_density(self, means, stds, rand=False, no_warp=False): + """Helper function to output density.""" + # Encode input positions + if self.warp_fn is not None and not no_warp: + means, stds = coord.track_linearize(self.warp_fn, means, stds) + # contract [-2, 2] to [-1, 1] + bound = 2 + means = means / bound + stds = stds / bound + features = self.encoder(means, bound=1).unflatten(-1, (self.encoder.num_levels, -1)) + weights = torch.erf(1 / torch.sqrt(8 * stds[..., None] ** 2 * self.encoder.grid_sizes ** 2)) + features = (features * weights[..., None]).mean(dim=-3).flatten(-2, -1) + if self.scale_featurization: + with torch.no_grad(): + vl2mean = segment_coo((self.encoder.embeddings ** 2).sum(-1), + self.encoder.idx, + torch.zeros(self.grid_num_levels, device=weights.device), + self.grid_num_levels, + reduce='mean' + ) + featurized_w = (2 * weights.mean(dim=-2) - 1) * (self.encoder.init_std ** 2 + vl2mean).sqrt() + features = torch.cat([features, featurized_w], dim=-1) + x = self.density_layer(features) + raw_density = x[..., 0] # Hardcoded to a single channel. + # Add noise to regularize the density predictions if needed. + if rand and (self.density_noise > 0): + raw_density += self.density_noise * torch.randn_like(raw_density) + return raw_density, x, means.mean(dim=-2) + + def forward(self, + rand, + means, stds, + viewdirs=None, + imageplane=None, + glo_vec=None, + exposure=None, + no_warp=False): + """Evaluate the MLP. + + Args: + rand: if random . + means: [..., n, 3], coordinate means. + stds: [..., n], coordinate stds. + viewdirs: [..., 3], if not None, this variable will + be part of the input to the second part of the MLP concatenated with the + output vector of the first part of the MLP. If None, only the first part + of the MLP will be used with input x. In the original paper, this + variable is the view direction. + imageplane:[batch, 2], xy image plane coordinates + for each ray in the batch. Useful for image plane operations such as a + learned vignette mapping. + glo_vec: [..., num_glo_features], The GLO vector for each ray. 
+ exposure: [..., 1], exposure value (shutter_speed * ISO) for each ray. + + Returns: + rgb: [..., num_rgb_channels]. + density: [...]. + normals: [..., 3], or None. + normals_pred: [..., 3], or None. + roughness: [..., 1], or None. + """ + if self.disable_density_normals: + raw_density, x, means_contract = self.predict_density(means, stds, rand=rand, no_warp=no_warp) + raw_grad_density = None + normals = None + else: + with torch.enable_grad(): + means.requires_grad_(True) + raw_density, x, means_contract = self.predict_density(means, stds, rand=rand, no_warp=no_warp) + d_output = torch.ones_like(raw_density, requires_grad=False, device=raw_density.device) + raw_grad_density = torch.autograd.grad( + outputs=raw_density, + inputs=means, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + raw_grad_density = raw_grad_density.mean(-2) + # Compute normal vectors as negative normalized density gradient. + # We normalize the gradient of raw (pre-activation) density because + # it's the same as post-activation density, but is more numerically stable + # when the activation function has a steep or flat gradient. + normals = -ref_utils.l2_normalize(raw_grad_density) + + if self.enable_pred_normals: + grad_pred = self.normal_layer(x) + + # Normalize negative predicted gradients to get predicted normal vectors. + normals_pred = -ref_utils.l2_normalize(grad_pred) + normals_to_use = normals_pred + else: + grad_pred = None + normals_pred = None + normals_to_use = normals + + # Apply bias and activation to raw density + density = F.softplus(raw_density + self.density_bias) + + roughness = None + if self.disable_rgb: + rgb = torch.zeros(density.shape + (3,), device=density.device) + else: + if viewdirs is not None: + # Predict diffuse color. + if self.use_diffuse_color: + raw_rgb_diffuse = self.diffuse_layer(x) + + if self.use_specular_tint: + tint = torch.sigmoid(self.specular_layer(x)) + + if self.enable_pred_roughness: + raw_roughness = self.roughness_layer(x) + roughness = (F.softplus(raw_roughness + self.roughness_bias)) + + # Output of the first part of MLP. + if self.bottleneck_width > 0: + bottleneck = x + # Add bottleneck noise. + if rand and (self.bottleneck_noise > 0): + bottleneck += self.bottleneck_noise * torch.randn_like(bottleneck) + + # Append GLO vector if used. + if glo_vec is not None: + for i in range(self.net_depth_glo): + glo_vec = self.get_submodule(f"lin_glo_{i}")(glo_vec) + if i != self.net_depth_glo - 1: + glo_vec = F.relu(glo_vec) + glo_vec = torch.broadcast_to(glo_vec[..., None, :], + bottleneck.shape[:-1] + glo_vec.shape[-1:]) + scale, shift = glo_vec.chunk(2, dim=-1) + bottleneck = bottleneck * torch.exp(scale) + shift + + x = [bottleneck] + else: + x = [] + + # Encode view (or reflection) directions. + if self.use_reflections: + # Compute reflection directions. Note that we flip viewdirs before + # reflecting, because they point from the camera to the point, + # whereas ref_utils.reflect() assumes they point toward the camera. + # Returned refdirs then point from the point to the environment. + refdirs = ref_utils.reflect(-viewdirs[..., None, :], normals_to_use) + # Encode reflection directions. + dir_enc = self.dir_enc_fn(refdirs, roughness) + else: + # Encode view directions. + dir_enc = self.dir_enc_fn(viewdirs, roughness) + dir_enc = torch.broadcast_to( + dir_enc[..., None, :], + bottleneck.shape[:-1] + (dir_enc.shape[-1],)) + + # Append view (or reflection) direction encoding to bottleneck vector. 
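+                # dir_enc has one copy per sample along the ray (broadcast from the per-ray view direction when
+                # reflections are not used), so it concatenates cleanly with the per-sample bottleneck features.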
+ x.append(dir_enc) + + # Append dot product between normal vectors and view directions. + if self.use_n_dot_v: + dotprod = torch.sum( + normals_to_use * viewdirs[..., None, :], dim=-1, keepdim=True) + x.append(dotprod) + + # Concatenate bottleneck, directional encoding, and GLO. + x = torch.cat(x, dim=-1) + # Output of the second part of MLP. + inputs = x + for i in range(self.net_depth_viewdirs): + x = self.get_submodule(f"lin_second_stage_{i}")(x) + x = F.relu(x) + if i == self.skip_layer_dir: + x = torch.cat([x, inputs], dim=-1) + # If using diffuse/specular colors, then `rgb` is treated as linear + # specular color. Otherwise it's treated as the color itself. + rgb = torch.sigmoid(self.rgb_premultiplier * + self.rgb_layer(x) + + self.rgb_bias) + + if self.use_diffuse_color: + # Initialize linear diffuse color around 0.25, so that the combined + # linear color is initialized around 0.5. + diffuse_linear = torch.sigmoid(raw_rgb_diffuse - np.log(3.0)) + if self.use_specular_tint: + specular_linear = tint * rgb + else: + specular_linear = 0.5 * rgb + + # Combine specular and diffuse components and tone map to sRGB. + rgb = torch.clip(image.linear_to_srgb(specular_linear + diffuse_linear), 0.0, 1.0) + + # Apply padding, mapping color to [-rgb_padding, 1+rgb_padding]. + rgb = rgb * (1 + 2 * self.rgb_padding) - self.rgb_padding + + return dict( + coord=means_contract, + density=density, + rgb=rgb, + raw_grad_density=raw_grad_density, + grad_pred=grad_pred, + normals=normals, + normals_pred=normals_pred, + roughness=roughness, + ) + + +@gin.configurable +class NerfMLP(MLP): + pass + + +@gin.configurable +class PropMLP(MLP): + pass + + +@torch.no_grad() +def render_image(model, + accelerator: accelerate.Accelerator, + batch, + rand, + train_frac, + config, + verbose=True, + return_weights=False): + """Render all the pixels of an image (in test mode). + + Args: + render_fn: function, jit-ed render function mapping (rand, batch) -> pytree. + accelerator: used for DDP. + batch: a `Rays` pytree, the rays to be rendered. + rand: if random + config: A Config class. + + Returns: + rgb: rendered color image. + disp: rendered disparity image. + acc: rendered accumulated weights per pixel. + """ + model.eval() + + height, width = batch['origins'].shape[:2] + num_rays = height * width + batch = {k: v.reshape((num_rays, -1)) for k, v in batch.items() if v is not None} + + global_rank = accelerator.process_index + chunks = [] + idx0s = tqdm(range(0, num_rays, config.render_chunk_size), + desc="Rendering chunk", leave=False, + disable=not (accelerator.is_main_process and verbose)) + + for i_chunk, idx0 in enumerate(idx0s): + chunk_batch = tree_map(lambda r: r[idx0:idx0 + config.render_chunk_size], batch) + actual_chunk_size = chunk_batch['origins'].shape[0] + rays_remaining = actual_chunk_size % accelerator.num_processes + if rays_remaining != 0: + padding = accelerator.num_processes - rays_remaining + chunk_batch = tree_map(lambda v: torch.cat([v, torch.zeros_like(v[-padding:])], dim=0), chunk_batch) + else: + padding = 0 + # After padding the number of chunk_rays is always divisible by host_count. 
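+        # Each process renders its contiguous slice of the chunk; accelerator.gather below reassembles the full chunk,
+        # and the zero-padding added above is trimmed off again.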
+ rays_per_host = chunk_batch['origins'].shape[0] // accelerator.num_processes + start, stop = global_rank * rays_per_host, (global_rank + 1) * rays_per_host + chunk_batch = tree_map(lambda r: r[start:stop], chunk_batch) + + with accelerator.autocast(): + chunk_renderings, ray_history = model(rand, + chunk_batch, + train_frac=train_frac, + compute_extras=True, + zero_glo=True) + + gather = lambda v: accelerator.gather(v.contiguous())[:-padding] \ + if padding > 0 else accelerator.gather(v.contiguous()) + # Unshard the renderings. + chunk_renderings = tree_map(gather, chunk_renderings) + + # Gather the final pass for 2D buffers and all passes for ray bundles. + chunk_rendering = chunk_renderings[-1] + for k in chunk_renderings[0]: + if k.startswith('ray_'): + chunk_rendering[k] = [r[k] for r in chunk_renderings] + + if return_weights: + chunk_rendering['weights'] = gather(ray_history[-1]['weights']) + chunk_rendering['coord'] = gather(ray_history[-1]['coord']) + chunks.append(chunk_rendering) + + # Concatenate all chunks within each leaf of a single pytree. + rendering = {} + for k in chunks[0].keys(): + if isinstance(chunks[0][k], list): + rendering[k] = [] + for i in range(len(chunks[0][k])): + rendering[k].append(torch.cat([item[k][i] for item in chunks])) + else: + rendering[k] = torch.cat([item[k] for item in chunks]) + + for k, z in rendering.items(): + if not k.startswith('ray_'): + # Reshape 2D buffers into original image shape. + rendering[k] = z.reshape((height, width) + z.shape[1:]) + + # After all of the ray bundles have been concatenated together, extract a + # new random bundle (deterministically) from the concatenation that is the + # same size as one of the individual bundles. + keys = [k for k in rendering if k.startswith('ray_')] + if keys: + num_rays = rendering[keys[0]][0].shape[0] + ray_idx = torch.randperm(num_rays) + ray_idx = ray_idx[:config.vis_num_rays] + for k in keys: + rendering[k] = [r[ray_idx] for r in rendering[k]] + model.train() + return rendering diff --git a/internal/pycolmap/.gitignore b/internal/pycolmap/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..cbf0fa605e414d9d685de0d76b7370718c4cc13a --- /dev/null +++ b/internal/pycolmap/.gitignore @@ -0,0 +1,2 @@ +*.pyc +*.sw* diff --git a/internal/pycolmap/LICENSE.txt b/internal/pycolmap/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..5156d3d49e0c312561c59680658b6261f635abe3 --- /dev/null +++ b/internal/pycolmap/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 True Price, UNC Chapel Hill + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/internal/pycolmap/README.md b/internal/pycolmap/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2b8ca6b36cbe8091ac78d954c2ea723303eb0397 --- /dev/null +++ b/internal/pycolmap/README.md @@ -0,0 +1,4 @@ +# pycolmap +Python interface for COLMAP reconstructions, plus some convenient scripts for loading/modifying/converting reconstructions. + +This code does not, however, run reconstruction -- it only provides a convenient interface for handling COLMAP's output. diff --git a/internal/pycolmap/pycolmap/__init__.py b/internal/pycolmap/pycolmap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c194c20574748574d1a9323a051263f0a4bae75 --- /dev/null +++ b/internal/pycolmap/pycolmap/__init__.py @@ -0,0 +1,5 @@ +from camera import Camera +from database import COLMAPDatabase +from image import Image +from scene_manager import SceneManager +from rotation import Quaternion, DualQuaternion diff --git a/internal/pycolmap/pycolmap/camera.py b/internal/pycolmap/pycolmap/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5e5b184a03e11272c974d4349e09867001cfb4 --- /dev/null +++ b/internal/pycolmap/pycolmap/camera.py @@ -0,0 +1,259 @@ +# Author: True Price + +import numpy as np + +from scipy.optimize import root + + +#------------------------------------------------------------------------------- +# +# camera distortion functions for arrays of size (..., 2) +# +#------------------------------------------------------------------------------- + +def simple_radial_distortion(camera, x): + return x * (1. + camera.k1 * np.square(x).sum(axis=-1, keepdims=True)) + +def radial_distortion(camera, x): + r_sq = np.square(x).sum(axis=-1, keepdims=True) + return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq)) + +def opencv_distortion(camera, x): + x_sq = np.square(x) + xy = np.prod(x, axis=-1, keepdims=True) + r_sq = x_sq.sum(axis=-1, keepdims=True) + + return x * (1. + r_sq * (camera.k1 + camera.k2 * r_sq)) + np.concatenate(( + 2. * camera.p1 * xy + camera.p2 * (r_sq + 2. * x_sq), + camera.p1 * (r_sq + 2. * y_sq) + 2. 
* camera.p2 * xy), + axis=-1) + + +#------------------------------------------------------------------------------- +# +# Camera +# +#------------------------------------------------------------------------------- + +class Camera: + @staticmethod + def GetNumParams(type_): + if type_ == 0 or type_ == 'SIMPLE_PINHOLE': + return 3 + if type_ == 1 or type_ == 'PINHOLE': + return 4 + if type_ == 2 or type_ == 'SIMPLE_RADIAL': + return 4 + if type_ == 3 or type_ == 'RADIAL': + return 5 + if type_ == 4 or type_ == 'OPENCV': + return 8 + #if type_ == 5 or type_ == 'OPENCV_FISHEYE': + # return 8 + #if type_ == 6 or type_ == 'FULL_OPENCV': + # return 12 + #if type_ == 7 or type_ == 'FOV': + # return 5 + #if type_ == 8 or type_ == 'SIMPLE_RADIAL_FISHEYE': + # return 4 + #if type_ == 9 or type_ == 'RADIAL_FISHEYE': + # return 5 + #if type_ == 10 or type_ == 'THIN_PRISM_FISHEYE': + # return 12 + + # TODO: not supporting other camera types, currently + raise Exception('Camera type not supported') + + + #--------------------------------------------------------------------------- + + @staticmethod + def GetNameFromType(type_): + if type_ == 0: return 'SIMPLE_PINHOLE' + if type_ == 1: return 'PINHOLE' + if type_ == 2: return 'SIMPLE_RADIAL' + if type_ == 3: return 'RADIAL' + if type_ == 4: return 'OPENCV' + #if type_ == 5: return 'OPENCV_FISHEYE' + #if type_ == 6: return 'FULL_OPENCV' + #if type_ == 7: return 'FOV' + #if type_ == 8: return 'SIMPLE_RADIAL_FISHEYE' + #if type_ == 9: return 'RADIAL_FISHEYE' + #if type_ == 10: return 'THIN_PRISM_FISHEYE' + + raise Exception('Camera type not supported') + + + #--------------------------------------------------------------------------- + + def __init__(self, type_, width_, height_, params): + self.width = width_ + self.height = height_ + + if type_ == 0 or type_ == 'SIMPLE_PINHOLE': + self.fx, self.cx, self.cy = params + self.fy = self.fx + self.distortion_func = None + self.camera_type = 0 + + elif type_ == 1 or type_ == 'PINHOLE': + self.fx, self.fy, self.cx, self.cy = params + self.distortion_func = None + self.camera_type = 1 + + elif type_ == 2 or type_ == 'SIMPLE_RADIAL': + self.fx, self.cx, self.cy, self.k1 = params + self.fy = self.fx + self.distortion_func = simple_radial_distortion + self.camera_type = 2 + + elif type_ == 3 or type_ == 'RADIAL': + self.fx, self.cx, self.cy, self.k1, self.k2 = params + self.fy = self.fx + self.distortion_func = radial_distortion + self.camera_type = 3 + + elif type_ == 4 or type_ == 'OPENCV': + self.fx, self.fy, self.cx, self.cy = params[:4] + self.k1, self.k2, self.p1, self.p2 = params[4:] + self.distortion_func = opencv_distortion + self.camera_type = 4 + + else: + raise Exception('Camera type not supported') + + + #--------------------------------------------------------------------------- + + def __str__(self): + s = (self.GetNameFromType(self.camera_type) + + ' {} {} {}'.format(self.width, self.height, self.fx)) + + if self.camera_type in (1, 4): # PINHOLE, OPENCV + s += ' {}'.format(self.fy) + + s += ' {} {}'.format(self.cx, self.cy) + + if self.camera_type == 2: # SIMPLE_RADIAL + s += ' {}'.format(self.k1) + + elif self.camera_type == 3: # RADIAL + s += ' {} {}'.format(self.k1, self.k2) + + elif self.camera_type == 4: # OPENCV + s += ' {} {} {} {}'.format(self.k1, self.k2, self.p1, self.p2) + + return s + + + #--------------------------------------------------------------------------- + + # return the camera parameters in the same order as the colmap output format + def get_params(self): + if self.camera_type == 
0: + return np.array((self.fx, self.cx, self.cy)) + if self.camera_type == 1: + return np.array((self.fx, self.fy, self.cx, self.cy)) + if self.camera_type == 2: + return np.array((self.fx, self.cx, self.cy, self.k1)) + if self.camera_type == 3: + return np.array((self.fx, self.cx, self.cy, self.k1, self.k2)) + if self.camera_type == 4: + return np.array((self.fx, self.fy, self.cx, self.cy, self.k1, + self.k2, self.p1, self.p2)) + + + #--------------------------------------------------------------------------- + + def get_camera_matrix(self): + return np.array( + ((self.fx, 0, self.cx), (0, self.fy, self.cy), (0, 0, 1))) + + def get_inverse_camera_matrix(self): + return np.array( + ((1. / self.fx, 0, -self.cx / self.fx), + (0, 1. / self.fy, -self.cy / self.fy), + (0, 0, 1))) + + @property + def K(self): + return self.get_camera_matrix() + + @property + def K_inv(self): + return self.get_inverse_camera_matrix() + + #--------------------------------------------------------------------------- + + # return the inverse camera matrix + def get_inv_camera_matrix(self): + inv_fx, inv_fy = 1. / self.fx, 1. / self.fy + return np.array(((inv_fx, 0, -inv_fx * self.cx), + (0, inv_fy, -inv_fy * self.cy), + (0, 0, 1))) + + + #--------------------------------------------------------------------------- + + # return an (x, y) pixel coordinate grid for this camera + def get_image_grid(self): + xmin = (0.5 - self.cx) / self.fx + xmax = (self.width - 0.5 - self.cx) / self.fx + ymin = (0.5 - self.cy) / self.fy + ymax = (self.height - 0.5 - self.cy) / self.fy + return np.meshgrid(np.linspace(xmin, xmax, self.width), + np.linspace(ymin, ymax, self.height)) + + + #--------------------------------------------------------------------------- + + # x: array of shape (N,2) or (2,) + # normalized: False if the input points are in pixel coordinates + # denormalize: True if the points should be put back into pixel coordinates + def distort_points(self, x, normalized=True, denormalize=True): + x = np.atleast_2d(x) + + # put the points into normalized camera coordinates + if not normalized: + x -= np.array([[self.cx, self.cy]]) + x /= np.array([[self.fx, self.fy]]) + + # distort, if necessary + if self.distortion_func is not None: + x = self.distortion_func(self, x) + + if denormalize: + x *= np.array([[self.fx, self.fy]]) + x += np.array([[self.cx, self.cy]]) + + return x + + + #--------------------------------------------------------------------------- + + # x: array of shape (N1,N2,...,2), (N,2), or (2,) + # normalized: False if the input points are in pixel coordinates + # denormalize: True if the points should be put back into pixel coordinates + def undistort_points(self, x, normalized=False, denormalize=True): + x = np.atleast_2d(x) + + # put the points into normalized camera coordinates + if not normalized: + x = x - np.array([self.cx, self.cy]) # creates a copy + x /= np.array([self.fx, self.fy]) + + # undistort, if necessary + if self.distortion_func is not None: + def objective(xu): + return (x - self.distortion_func(self, xu.reshape(*x.shape)) + ).ravel() + + xu = root(objective, x).x.reshape(*x.shape) + else: + xu = x + + if denormalize: + xu *= np.array([[self.fx, self.fy]]) + xu += np.array([[self.cx, self.cy]]) + + return xu diff --git a/internal/pycolmap/pycolmap/database.py b/internal/pycolmap/pycolmap/database.py new file mode 100644 index 0000000000000000000000000000000000000000..c11948d8ec464c567c1581e6dd588350efa4c7a5 --- /dev/null +++ b/internal/pycolmap/pycolmap/database.py @@ -0,0 +1,340 @@ +import 
numpy as np +import os +import sqlite3 + + +#------------------------------------------------------------------------------- +# convert SQLite BLOBs to/from numpy arrays + +def array_to_blob(arr): + return np.getbuffer(arr) + +def blob_to_array(blob, dtype, shape=(-1,)): + return np.frombuffer(blob, dtype).reshape(*shape) + + +#------------------------------------------------------------------------------- +# convert to/from image pair ids + +MAX_IMAGE_ID = 2**31 - 1 + +def get_pair_id(image_id1, image_id2): + if image_id1 > image_id2: + image_id1, image_id2 = image_id2, image_id1 + return image_id1 * MAX_IMAGE_ID + image_id2 + + +def get_image_ids_from_pair_id(pair_id): + image_id2 = pair_id % MAX_IMAGE_ID + return (pair_id - image_id2) / MAX_IMAGE_ID, image_id2 + + +#------------------------------------------------------------------------------- +# create table commands + +CREATE_CAMERAS_TABLE = """CREATE TABLE IF NOT EXISTS cameras ( + camera_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + model INTEGER NOT NULL, + width INTEGER NOT NULL, + height INTEGER NOT NULL, + params BLOB, + prior_focal_length INTEGER NOT NULL)""" + +CREATE_DESCRIPTORS_TABLE = """CREATE TABLE IF NOT EXISTS descriptors ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" + +CREATE_IMAGES_TABLE = """CREATE TABLE IF NOT EXISTS images ( + image_id INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, + name TEXT NOT NULL UNIQUE, + camera_id INTEGER NOT NULL, + prior_qw REAL, + prior_qx REAL, + prior_qy REAL, + prior_qz REAL, + prior_tx REAL, + prior_ty REAL, + prior_tz REAL, + CONSTRAINT image_id_check CHECK(image_id >= 0 and image_id < 2147483647), + FOREIGN KEY(camera_id) REFERENCES cameras(camera_id))""" + +CREATE_INLIER_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS two_view_geometries ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + config INTEGER NOT NULL, + F BLOB, + E BLOB, + H BLOB)""" + +CREATE_KEYPOINTS_TABLE = """CREATE TABLE IF NOT EXISTS keypoints ( + image_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB, + FOREIGN KEY(image_id) REFERENCES images(image_id) ON DELETE CASCADE)""" + +CREATE_MATCHES_TABLE = """CREATE TABLE IF NOT EXISTS matches ( + pair_id INTEGER PRIMARY KEY NOT NULL, + rows INTEGER NOT NULL, + cols INTEGER NOT NULL, + data BLOB)""" + +CREATE_NAME_INDEX = \ + "CREATE UNIQUE INDEX IF NOT EXISTS index_name ON images(name)" + +CREATE_ALL = "; ".join([CREATE_CAMERAS_TABLE, CREATE_DESCRIPTORS_TABLE, + CREATE_IMAGES_TABLE, CREATE_INLIER_MATCHES_TABLE, CREATE_KEYPOINTS_TABLE, + CREATE_MATCHES_TABLE, CREATE_NAME_INDEX]) + + +#------------------------------------------------------------------------------- +# functional interface for adding objects + +def add_camera(db, model, width, height, params, prior_focal_length=False, + camera_id=None): + # TODO: Parameter count checks + params = np.asarray(params, np.float64) + db.execute("INSERT INTO cameras VALUES (?, ?, ?, ?, ?, ?)", + (camera_id, model, width, height, array_to_blob(params), + prior_focal_length)) + + +def add_descriptors(db, image_id, descriptors): + descriptors = np.ascontiguousarray(descriptors, np.uint8) + db.execute("INSERT INTO descriptors VALUES (?, ?, ?, ?)", + (image_id,) + descriptors.shape + (array_to_blob(descriptors),)) + + +def add_image(db, name, camera_id, prior_q=np.zeros(4), prior_t=np.zeros(3), + 
image_id=None): + db.execute("INSERT INTO images VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (image_id, name, camera_id, prior_q[0], prior_q[1], prior_q[2], + prior_q[3], prior_t[0], prior_t[1], prior_t[2])) + + +# config: two-view geometry configuration type (2 = calibrated) +def add_inlier_matches(db, image_id1, image_id2, matches, config=2, F=None, + E=None, H=None): + assert(len(matches.shape) == 2) + assert(matches.shape[1] == 2) + + if image_id1 > image_id2: + matches = matches[:,::-1] + + if F is not None: + F = array_to_blob(np.asarray(F, np.float64)) + if E is not None: + E = array_to_blob(np.asarray(E, np.float64)) + if H is not None: + H = array_to_blob(np.asarray(H, np.float64)) + + pair_id = get_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + db.execute("INSERT INTO two_view_geometries VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches), config, F, E, H)) + + +def add_keypoints(db, image_id, keypoints): + assert(len(keypoints.shape) == 2) + assert(keypoints.shape[1] in [2, 4, 6]) + + keypoints = np.asarray(keypoints, np.float32) + db.execute("INSERT INTO keypoints VALUES (?, ?, ?, ?)", + (image_id,) + keypoints.shape + (array_to_blob(keypoints),)) + + +def add_matches(db, image_id1, image_id2, matches): + assert(len(matches.shape) == 2) + assert(matches.shape[1] == 2) + + if image_id1 > image_id2: + matches = matches[:,::-1] + + pair_id = get_pair_id(image_id1, image_id2) + matches = np.asarray(matches, np.uint32) + db.execute("INSERT INTO matches VALUES (?, ?, ?, ?)", + (pair_id,) + matches.shape + (array_to_blob(matches),)) + + +#------------------------------------------------------------------------------- +# database connection class wrapping the functional interface above + +class COLMAPDatabase(sqlite3.Connection): + @staticmethod + def connect(database_path): + return sqlite3.connect(database_path, factory=COLMAPDatabase) + + + def __init__(self, *args, **kwargs): + super(COLMAPDatabase, self).__init__(*args, **kwargs) + + self.initialize_tables = lambda: self.executescript(CREATE_ALL) + + self.initialize_cameras = \ + lambda: self.executescript(CREATE_CAMERAS_TABLE) + self.initialize_descriptors = \ + lambda: self.executescript(CREATE_DESCRIPTORS_TABLE) + self.initialize_images = \ + lambda: self.executescript(CREATE_IMAGES_TABLE) + self.initialize_inlier_matches = \ + lambda: self.executescript(CREATE_INLIER_MATCHES_TABLE) + self.initialize_keypoints = \ + lambda: self.executescript(CREATE_KEYPOINTS_TABLE) + self.initialize_matches = \ + lambda: self.executescript(CREATE_MATCHES_TABLE) + + self.create_name_index = lambda: self.executescript(CREATE_NAME_INDEX) + + + add_camera = add_camera + add_descriptors = add_descriptors + add_image = add_image + add_inlier_matches = add_inlier_matches + add_keypoints = add_keypoints + add_matches = add_matches + + +#------------------------------------------------------------------------------- + +def main(args): + import os + + if os.path.exists(args.database_path): + print("Error: database path already exists -- will not modify it.") + exit() + + db = COLMAPDatabase.connect(args.database_path) + + # + # for convenience, try creating all the tables upfront + # + + db.initialize_tables() + + + # + # create dummy cameras + # + + model1, w1, h1, params1 = 0, 1024, 768, np.array((1024., 512., 384.)) + model2, w2, h2, params2 = 2, 1024, 768, np.array((1024., 512., 384., 0.1)) + + db.add_camera(model1, w1, h1, params1) + db.add_camera(model2, w2, h2, params2) + + + # + # create dummy images + # + + db.add_image("image1.png", 0) + 
db.add_image("image2.png", 0) + db.add_image("image3.png", 2) + db.add_image("image4.png", 2) + + + # + # create dummy keypoints; note that COLMAP supports 2D keypoints (x, y), + # 4D keypoints (x, y, theta, scale), and 6D affine keypoints + # (x, y, a_11, a_12, a_21, a_22) + # + + N = 1000 + kp1 = np.random.rand(N, 2) * (1024., 768.) + kp2 = np.random.rand(N, 2) * (1024., 768.) + kp3 = np.random.rand(N, 2) * (1024., 768.) + kp4 = np.random.rand(N, 2) * (1024., 768.) + + db.add_keypoints(1, kp1) + db.add_keypoints(2, kp2) + db.add_keypoints(3, kp3) + db.add_keypoints(4, kp4) + + + # + # create dummy matches + # + + M = 50 + m12 = np.random.randint(N, size=(M, 2)) + m23 = np.random.randint(N, size=(M, 2)) + m34 = np.random.randint(N, size=(M, 2)) + + db.add_matches(1, 2, m12) + db.add_matches(2, 3, m23) + db.add_matches(3, 4, m34) + + + # + # check cameras + # + + rows = db.execute("SELECT * FROM cameras") + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float32) + assert model == model1 and width == w1 and height == h1 + assert np.allclose(params, params1) + + camera_id, model, width, height, params, prior = next(rows) + params = blob_to_array(params, np.float32) + assert model == model2 and width == w2 and height == h2 + assert np.allclose(params, params2) + + + # + # check keypoints + # + + kps = dict( + (image_id, blob_to_array(data, np.float32, (-1, 2))) + for image_id, data in db.execute( + "SELECT image_id, data FROM keypoints")) + + assert np.allclose(kps[1], kp1) + assert np.allclose(kps[2], kp2) + assert np.allclose(kps[3], kp3) + assert np.allclose(kps[4], kp4) + + + # + # check matches + # + + pair_ids = [get_pair_id(*pair) for pair in [(1, 2), (2, 3), (3, 4)]] + + matches = dict( + (get_image_ids_from_pair_id(pair_id), + blob_to_array(data, np.uint32, (-1, 2))) + for pair_id, data in db.execute("SELECT pair_id, data FROM matches")) + + assert np.all(matches[(1, 2)] == m12) + assert np.all(matches[(2, 3)] == m23) + assert np.all(matches[(3, 4)] == m34) + + # + # clean up + # + + db.close() + os.remove(args.database_path) + +#------------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("--database_path", type=str, default="database.db") + + args = parser.parse_args() + + main(args) diff --git a/internal/pycolmap/pycolmap/image.py b/internal/pycolmap/pycolmap/image.py new file mode 100644 index 0000000000000000000000000000000000000000..14efa32b0a91f116cbd7836b6480a60b10371196 --- /dev/null +++ b/internal/pycolmap/pycolmap/image.py @@ -0,0 +1,35 @@ +# Author: True Price + +import numpy as np + +#------------------------------------------------------------------------------- +# +# Image +# +#------------------------------------------------------------------------------- + +class Image: + def __init__(self, name_, camera_id_, q_, tvec_): + self.name = name_ + self.camera_id = camera_id_ + self.q = q_ + self.tvec = tvec_ + + self.points2D = np.empty((0, 2), dtype=np.float64) + self.point3D_ids = np.empty((0,), dtype=np.uint64) + + #--------------------------------------------------------------------------- + + def R(self): + return self.q.ToR() + + #--------------------------------------------------------------------------- + + def C(self): + return -self.R().T.dot(self.tvec) + + 
#--------------------------------------------------------------------------- + + @property + def t(self): + return self.tvec diff --git a/internal/pycolmap/pycolmap/rotation.py b/internal/pycolmap/pycolmap/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..f0b4e811620e9668e8a44b1fd15e0574a7307f6a --- /dev/null +++ b/internal/pycolmap/pycolmap/rotation.py @@ -0,0 +1,324 @@ +# Author: True Price + +import numpy as np + +#------------------------------------------------------------------------------- +# +# Axis-Angle Functions +# +#------------------------------------------------------------------------------- + +# returns the cross product matrix representation of a 3-vector v +def cross_prod_matrix(v): + return np.array(((0., -v[2], v[1]), (v[2], 0., -v[0]), (-v[1], v[0], 0.))) + +#------------------------------------------------------------------------------- + +# www.euclideanspace.com/maths/geometry/rotations/conversions/angleToMatrix/ +# if angle is None, assume ||axis|| == angle, in radians +# if angle is not None, assume that axis is a unit vector +def axis_angle_to_rotation_matrix(axis, angle=None): + if angle is None: + angle = np.linalg.norm(axis) + if np.abs(angle) > np.finfo('float').eps: + axis = axis / angle + + cp_axis = cross_prod_matrix(axis) + return np.eye(3) + ( + np.sin(angle) * cp_axis + (1. - np.cos(angle)) * cp_axis.dot(cp_axis)) + +#------------------------------------------------------------------------------- + +# after some deliberation, I've decided the easiest way to do this is to use +# quaternions as an intermediary +def rotation_matrix_to_axis_angle(R): + return Quaternion.FromR(R).ToAxisAngle() + +#------------------------------------------------------------------------------- +# +# Quaternion +# +#------------------------------------------------------------------------------- + +class Quaternion: + # create a quaternion from an existing rotation matrix + # euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/ + @staticmethod + def FromR(R): + trace = np.trace(R) + + if trace > 0: + qw = 0.5 * np.sqrt(1. + trace) + qx = (R[2,1] - R[1,2]) * 0.25 / qw + qy = (R[0,2] - R[2,0]) * 0.25 / qw + qz = (R[1,0] - R[0,1]) * 0.25 / qw + elif R[0,0] > R[1,1] and R[0,0] > R[2,2]: + s = 2. * np.sqrt(1. + R[0,0] - R[1,1] - R[2,2]) + qw = (R[2,1] - R[1,2]) / s + qx = 0.25 * s + qy = (R[0,1] + R[1,0]) / s + qz = (R[0,2] + R[2,0]) / s + elif R[1,1] > R[2,2]: + s = 2. * np.sqrt(1. + R[1,1] - R[0,0] - R[2,2]) + qw = (R[0,2] - R[2,0]) / s + qx = (R[0,1] + R[1,0]) / s + qy = 0.25 * s + qz = (R[1,2] + R[2,1]) / s + else: + s = 2. * np.sqrt(1. 
+ R[2,2] - R[0,0] - R[1,1]) + qw = (R[1,0] - R[0,1]) / s + qx = (R[0,2] + R[2,0]) / s + qy = (R[1,2] + R[2,1]) / s + qz = 0.25 * s + + return Quaternion(np.array((qw, qx, qy, qz))) + + # if angle is None, assume ||axis|| == angle, in radians + # if angle is not None, assume that axis is a unit vector + @staticmethod + def FromAxisAngle(axis, angle=None): + if angle is None: + angle = np.linalg.norm(axis) + if np.abs(angle) > np.finfo('float').eps: + axis = axis / angle + + qw = np.cos(0.5 * angle) + axis = axis * np.sin(0.5 * angle) + + return Quaternion(np.array((qw, axis[0], axis[1], axis[2]))) + + #--------------------------------------------------------------------------- + + def __init__(self, q=np.array((1., 0., 0., 0.))): + if isinstance(q, Quaternion): + self.q = q.q.copy() + else: + q = np.asarray(q) + if q.size == 4: + self.q = q.copy() + elif q.size == 3: # convert from a 3-vector to a quaternion + self.q = np.empty(4) + self.q[0], self.q[1:] = 0., q.ravel() + else: + raise Exception('Input quaternion should be a 3- or 4-vector') + + def __add__(self, other): + return Quaternion(self.q + other.q) + + def __iadd__(self, other): + self.q += other.q + return self + + # conjugation via the ~ operator + def __invert__(self): + return Quaternion( + np.array((self.q[0], -self.q[1], -self.q[2], -self.q[3]))) + + # returns: self.q * other.q if other is a Quaternion; otherwise performs + # scalar multiplication + def __mul__(self, other): + if isinstance(other, Quaternion): # quaternion multiplication + return Quaternion(np.array(( + self.q[0] * other.q[0] - self.q[1] * other.q[1] - + self.q[2] * other.q[2] - self.q[3] * other.q[3], + self.q[0] * other.q[1] + self.q[1] * other.q[0] + + self.q[2] * other.q[3] - self.q[3] * other.q[2], + self.q[0] * other.q[2] - self.q[1] * other.q[3] + + self.q[2] * other.q[0] + self.q[3] * other.q[1], + self.q[0] * other.q[3] + self.q[1] * other.q[2] - + self.q[2] * other.q[1] + self.q[3] * other.q[0]))) + else: # scalar multiplication (assumed) + return Quaternion(other * self.q) + + def __rmul__(self, other): + return self * other + + def __imul__(self, other): + self.q[:] = (self * other).q + return self + + def __irmul__(self, other): + self.q[:] = (self * other).q + return self + + def __neg__(self): + return Quaternion(-self.q) + + def __sub__(self, other): + return Quaternion(self.q - other.q) + + def __isub__(self, other): + self.q -= other.q + return self + + def __str__(self): + return str(self.q) + + def copy(self): + return Quaternion(self) + + def dot(self, other): + return self.q.dot(other.q) + + # assume the quaternion is nonzero! 
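+ # q^-1 = ~q / ||q||^2, so for a unit quaternion the inverse is simply the
+ # conjugate; for any nonzero q, q.inverse() * q equals the identity
+ # quaternion (1, 0, 0, 0) up to floating-point error.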
+ def inverse(self): + return Quaternion((~self).q / self.q.dot(self.q)) + + def norm(self): + return np.linalg.norm(self.q) + + def normalize(self): + self.q /= np.linalg.norm(self.q) + return self + + # assume x is a Nx3 numpy array or a numpy 3-vector + def rotate_points(self, x): + x = np.atleast_2d(x) + return x.dot(self.ToR().T) + + # convert to a rotation matrix + def ToR(self): + return np.eye(3) + 2 * np.array(( + (-self.q[2] * self.q[2] - self.q[3] * self.q[3], + self.q[1] * self.q[2] - self.q[3] * self.q[0], + self.q[1] * self.q[3] + self.q[2] * self.q[0]), + ( self.q[1] * self.q[2] + self.q[3] * self.q[0], + -self.q[1] * self.q[1] - self.q[3] * self.q[3], + self.q[2] * self.q[3] - self.q[1] * self.q[0]), + ( self.q[1] * self.q[3] - self.q[2] * self.q[0], + self.q[2] * self.q[3] + self.q[1] * self.q[0], + -self.q[1] * self.q[1] - self.q[2] * self.q[2]))) + + # convert to axis-angle representation, with angle encoded by the length + def ToAxisAngle(self): + # recall that for axis-angle representation (a, angle), with "a" unit: + # q = (cos(angle/2), a * sin(angle/2)) + # below, for readability, "theta" actually means half of the angle + + sin_sq_theta = self.q[1:].dot(self.q[1:]) + + # if theta is non-zero, then we can compute a unique rotation + if np.abs(sin_sq_theta) > np.finfo('float').eps: + sin_theta = np.sqrt(sin_sq_theta) + cos_theta = self.q[0] + + # atan2 is more stable, so we use it to compute theta + # note that we multiply by 2 to get the actual angle + angle = 2. * ( + np.arctan2(-sin_theta, -cos_theta) if cos_theta < 0. else + np.arctan2(sin_theta, cos_theta)) + + return self.q[1:] * (angle / sin_theta) + + # otherwise, the result is singular, and we avoid dividing by + # sin(angle/2) = 0 + return np.zeros(3) + + # euclideanspace.com/maths/geometry/rotations/conversions/quaternionToEuler + # this assumes the quaternion is non-zero + # returns yaw, pitch, roll, with application in that order + def ToEulerAngles(self): + qsq = self.q**2 + k = 2. * (self.q[0] * self.q[3] + self.q[1] * self.q[2]) / qsq.sum() + + if (1. - k) < np.finfo('float').eps: # north pole singularity + return 2. * np.arctan2(self.q[1], self.q[0]), 0.5 * np.pi, 0. + if (1. + k) < np.finfo('float').eps: # south pole singularity + return -2. * np.arctan2(self.q[1], self.q[0]), -0.5 * np.pi, 0. + + yaw = np.arctan2(2. * (self.q[0] * self.q[2] - self.q[1] * self.q[3]), + qsq[0] + qsq[1] - qsq[2] - qsq[3]) + pitch = np.arcsin(k) + roll = np.arctan2(2. 
* (self.q[0] * self.q[1] - self.q[2] * self.q[3]), + qsq[0] - qsq[1] + qsq[2] - qsq[3]) + + return yaw, pitch, roll + +#------------------------------------------------------------------------------- +# +# DualQuaternion +# +#------------------------------------------------------------------------------- + +class DualQuaternion: + # DualQuaternion from an existing rotation + translation + @staticmethod + def FromQT(q, t): + return DualQuaternion(qe=(0.5 * np.asarray(t))) * DualQuaternion(q) + + def __init__(self, q0=np.array((1., 0., 0., 0.)), qe=np.zeros(4)): + self.q0, self.qe = Quaternion(q0), Quaternion(qe) + + def __add__(self, other): + return DualQuaternion(self.q0 + other.q0, self.qe + other.qe) + + def __iadd__(self, other): + self.q0 += other.q0 + self.qe += other.qe + return self + + # conjugation via the ~ operator + def __invert__(self): + return DualQuaternion(~self.q0, ~self.qe) + + def __mul__(self, other): + if isinstance(other, DualQuaternion): + return DualQuaternion( + self.q0 * other.q0, + self.q0 * other.qe + self.qe * other.q0) + elif isinstance(other, complex): # multiplication by a dual number + return DualQuaternion( + self.q0 * other.real, + self.q0 * other.imag + self.qe * other.real) + else: # scalar multiplication (assumed) + return DualQuaternion(other * self.q0, other * self.qe) + + def __rmul__(self, other): + return self.__mul__(other) + + def __imul__(self, other): + tmp = self * other + self.q0, self.qe = tmp.q0, tmp.qe + return self + + def __neg__(self): + return DualQuaternion(-self.q0, -self.qe) + + def __sub__(self, other): + return DualQuaternion(self.q0 - other.q0, self.qe - other.qe) + + def __isub__(self, other): + self.q0 -= other.q0 + self.qe -= other.qe + return self + + # q^-1 = q* / ||q||^2 + # assume that q0 is nonzero! + def inverse(self): + normsq = complex(self.q0.dot(self.q0), 2. * self.q0.q.dot(self.qe.q)) + inv_len_real = 1. / normsq.real + return ~self * complex( + inv_len_real, -normsq.imag * inv_len_real * inv_len_real) + + # returns a complex representation of the real and imaginary parts of the norm + # assume that q0 is nonzero! + def norm(self): + q0_norm = self.q0.norm() + return complex(q0_norm, self.q0.dot(self.qe) / q0_norm) + + # assume that q0 is nonzero! + def normalize(self): + # current length is ||q0|| + eps * (q0 . qe / ||q0||) + # writing this as a + eps * b, the inverse is + # 1/||q|| = 1/a - eps * b / a^2 + norm = self.norm() + inv_len_real = 1. 
/ norm.real + self *= complex(inv_len_real, -norm.imag * inv_len_real * inv_len_real) + return self + + # return the translation vector for this dual quaternion + def getT(self): + return 2 * (self.qe * ~self.q0).q[1:] + + def ToQT(self): + return self.q0, self.getT() diff --git a/internal/pycolmap/pycolmap/scene_manager.py b/internal/pycolmap/pycolmap/scene_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..bc690ebefb23ec0972e3f44a0c299219a53abd3b --- /dev/null +++ b/internal/pycolmap/pycolmap/scene_manager.py @@ -0,0 +1,670 @@ +# Author: True Price + +import array +import numpy as np +import os +import struct + +from collections import OrderedDict +from itertools import combinations + +from camera import Camera +from image import Image +from rotation import Quaternion + +#------------------------------------------------------------------------------- +# +# SceneManager +# +#------------------------------------------------------------------------------- + +class SceneManager: + INVALID_POINT3D = np.uint64(-1) + + def __init__(self, colmap_results_folder, image_path=None): + self.folder = colmap_results_folder + if not self.folder.endswith('/'): + self.folder += '/' + + self.image_path = None + self.load_colmap_project_file(image_path=image_path) + + self.cameras = OrderedDict() + self.images = OrderedDict() + self.name_to_image_id = dict() + + self.last_camera_id = 0 + self.last_image_id = 0 + + # Nx3 array of point3D xyz's + self.points3D = np.zeros((0, 3)) + + # for each element in points3D, stores the id of the point + self.point3D_ids = np.empty(0) + + # point3D_id => index in self.points3D + self.point3D_id_to_point3D_idx = dict() + + # point3D_id => [(image_id, point2D idx in image)] + self.point3D_id_to_images = dict() + + self.point3D_colors = np.zeros((0, 3), dtype=np.uint8) + self.point3D_errors = np.zeros(0) + + #--------------------------------------------------------------------------- + + def load_colmap_project_file(self, project_file=None, image_path=None): + if project_file is None: + project_file = self.folder + 'project.ini' + + self.image_path = image_path + + if self.image_path is None: + try: + with open(project_file, 'r') as f: + for line in iter(f.readline, ''): + if line.startswith('image_path'): + self.image_path = line[11:].strip() + break + except: + pass + + if self.image_path is None: + print('Warning: image_path not found for reconstruction') + elif not self.image_path.endswith('/'): + self.image_path += '/' + + #--------------------------------------------------------------------------- + + def load(self): + self.load_cameras() + self.load_images() + self.load_points3D() + + #--------------------------------------------------------------------------- + + def load_cameras(self, input_file=None): + if input_file is None: + input_file = self.folder + 'cameras.bin' + if os.path.exists(input_file): + self._load_cameras_bin(input_file) + else: + input_file = self.folder + 'cameras.txt' + if os.path.exists(input_file): + self._load_cameras_txt(input_file) + else: + raise IOError('no cameras file found') + + def _load_cameras_bin(self, input_file): + self.cameras = OrderedDict() + + with open(input_file, 'rb') as f: + num_cameras = struct.unpack('L', f.read(8))[0] + + for _ in range(num_cameras): + camera_id, camera_type, w, h = struct.unpack('IiLL', f.read(24)) + num_params = Camera.GetNumParams(camera_type) + params = struct.unpack('d' * num_params, f.read(8 * num_params)) + self.cameras[camera_id] = Camera(camera_type, w, h, params) 
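+ # note: the native struct formats 'L' and 'IiLL' above assume a platform
+ # where a C long is 8 bytes; on platforms with 4-byte longs, an explicit
+ # little-endian layout (e.g. struct.Struct('<IiQQ')) would be needed to
+ # match COLMAP's cameras.bin encoding.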
+ self.last_camera_id = max(self.last_camera_id, camera_id) + + def _load_cameras_txt(self, input_file): + self.cameras = OrderedDict() + + with open(input_file, 'r') as f: + for line in iter(lambda: f.readline().strip(), ''): + if not line or line.startswith('#'): + continue + + data = line.split() + camera_id = int(data[0]) + self.cameras[camera_id] = Camera( + data[1], int(data[2]), int(data[3]), map(float, data[4:])) + self.last_camera_id = max(self.last_camera_id, camera_id) + + #--------------------------------------------------------------------------- + + def load_images(self, input_file=None): + if input_file is None: + input_file = self.folder + 'images.bin' + if os.path.exists(input_file): + self._load_images_bin(input_file) + else: + input_file = self.folder + 'images.txt' + if os.path.exists(input_file): + self._load_images_txt(input_file) + else: + raise IOError('no images file found') + + def _load_images_bin(self, input_file): + self.images = OrderedDict() + + with open(input_file, 'rb') as f: + num_images = struct.unpack('L', f.read(8))[0] + image_struct = struct.Struct('7x improvements in 60 image model, 23s -> 3s. + points_array = array.array('d') + points_array.fromfile(f, 3 * num_points2D) + points_elements = np.array(points_array).reshape((num_points2D, 3)) + image.points2D = points_elements[:, :2] + + ids_array = array.array('Q') + ids_array.frombytes(points_elements[:, 2].tobytes()) + image.point3D_ids = np.array(ids_array, dtype=np.uint64).reshape( + (num_points2D,)) + + # automatically remove points without an associated 3D point + #mask = (image.point3D_ids != SceneManager.INVALID_POINT3D) + #image.points2D = image.points2D[mask] + #image.point3D_ids = image.point3D_ids[mask] + + self.images[image_id] = image + self.name_to_image_id[image.name] = image_id + + self.last_image_id = max(self.last_image_id, image_id) + + def _load_images_txt(self, input_file): + self.images = OrderedDict() + + with open(input_file, 'r') as f: + is_camera_description_line = False + + for line in iter(lambda: f.readline().strip(), ''): + if not line or line.startswith('#'): + continue + + is_camera_description_line = not is_camera_description_line + + data = line.split() + + if is_camera_description_line: + image_id = int(data[0]) + image = Image(data[-1], int(data[-2]), + Quaternion(np.array(map(float, data[1:5]))), + np.array(map(float, data[5:8]))) + else: + image.points2D = np.array( + [map(float, data[::3]), map(float, data[1::3])]).T + image.point3D_ids = np.array(map(np.uint64, data[2::3])) + + # automatically remove points without an associated 3D point + #mask = (image.point3D_ids != SceneManager.INVALID_POINT3D) + #image.points2D = image.points2D[mask] + #image.point3D_ids = image.point3D_ids[mask] + + self.images[image_id] = image + self.name_to_image_id[image.name] = image_id + + self.last_image_id = max(self.last_image_id, image_id) + + #--------------------------------------------------------------------------- + + def load_points3D(self, input_file=None): + if input_file is None: + input_file = self.folder + 'points3D.bin' + if os.path.exists(input_file): + self._load_points3D_bin(input_file) + else: + input_file = self.folder + 'points3D.txt' + if os.path.exists(input_file): + self._load_points3D_txt(input_file) + else: + raise IOError('no points3D file found') + + def _load_points3D_bin(self, input_file): + with open(input_file, 'rb') as f: + num_points3D = struct.unpack('L', f.read(8))[0] + + self.points3D = np.empty((num_points3D, 3)) + self.point3D_ids = 
np.empty(num_points3D, dtype=np.uint64) + self.point3D_colors = np.empty((num_points3D, 3), dtype=np.uint8) + self.point3D_id_to_point3D_idx = dict() + self.point3D_id_to_images = dict() + self.point3D_errors = np.empty(num_points3D) + + data_struct = struct.Struct('>fid, '# Camera list with one line of data per camera:' + print>>fid, '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]' + print>>fid, '# Number of cameras:', len(self.cameras) + + for camera_id, camera in sorted(self.cameras.iteritems()): + print>>fid, camera_id, camera + + #--------------------------------------------------------------------------- + + def save_images(self, output_folder, output_file=None, binary=True): + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + if output_file is None: + output_file = 'images.bin' if binary else 'images.txt' + + output_file = os.path.join(output_folder, output_file) + + if binary: + self._save_images_bin(output_file) + else: + self._save_images_txt(output_file) + + def _save_images_bin(self, output_file): + with open(output_file, 'wb') as fid: + fid.write(struct.pack('L', len(self.images))) + + for image_id, image in self.images.iteritems(): + fid.write(struct.pack('I', image_id)) + fid.write(image.q.q.tobytes()) + fid.write(image.tvec.tobytes()) + fid.write(struct.pack('I', image.camera_id)) + fid.write(image.name + '\0') + fid.write(struct.pack('L', len(image.points2D))) + data = np.rec.fromarrays( + (image.points2D[:,0], image.points2D[:,1], image.point3D_ids)) + fid.write(data.tobytes()) + + def _save_images_txt(self, output_file): + with open(output_file, 'w') as fid: + print>>fid, '# Image list with two lines of data per image:' + print>>fid, '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME' + print>>fid, '# POINTS2D[] as (X, Y, POINT3D_ID)' + print>>fid, '# Number of images: {},'.format(len(self.images)), + print>>fid, 'mean observations per image: unknown' + + for image_id, image in self.images.iteritems(): + print>>fid, image_id, + print>>fid, ' '.join(str(qi) for qi in image.q.q), + print>>fid, ' '.join(str(ti) for ti in image.tvec), + print>>fid, image.camera_id, image.name + + data = np.rec.fromarrays( + (image.points2D[:,0], image.points2D[:,1], + image.point3D_ids.astype(np.int64))) + if len(data) > 0: + np.savetxt(fid, data, '%.2f %.2f %d', newline=' ') + fid.seek(-1, os.SEEK_CUR) + fid.write('\n') + + #--------------------------------------------------------------------------- + + def save_points3D(self, output_folder, output_file=None, binary=True): + if not os.path.exists(output_folder): + os.makedirs(output_folder) + + if output_file is None: + output_file = 'points3D.bin' if binary else 'points3D.txt' + + output_file = os.path.join(output_folder, output_file) + + if binary: + self._save_points3D_bin(output_file) + else: + self._save_points3D_txt(output_file) + + def _save_points3D_bin(self, output_file): + num_valid_points3D = sum( + 1 for point3D_idx in self.point3D_id_to_point3D_idx.itervalues() + if point3D_idx != SceneManager.INVALID_POINT3D) + + iter_point3D_id_to_point3D_idx = \ + self.point3D_id_to_point3D_idx.iteritems() + + with open(output_file, 'wb') as fid: + fid.write(struct.pack('L', num_valid_points3D)) + + for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx: + if point3D_idx == SceneManager.INVALID_POINT3D: + continue + + fid.write(struct.pack('L', point3D_id)) + fid.write(self.points3D[point3D_idx].tobytes()) + fid.write(self.point3D_colors[point3D_idx].tobytes()) + 
fid.write(self.point3D_errors[point3D_idx].tobytes()) + fid.write( + struct.pack('L', len(self.point3D_id_to_images[point3D_id]))) + fid.write(self.point3D_id_to_images[point3D_id].tobytes()) + + def _save_points3D_txt(self, output_file): + num_valid_points3D = sum( + 1 for point3D_idx in self.point3D_id_to_point3D_idx.itervalues() + if point3D_idx != SceneManager.INVALID_POINT3D) + + array_to_string = lambda arr: ' '.join(str(x) for x in arr) + + iter_point3D_id_to_point3D_idx = \ + self.point3D_id_to_point3D_idx.iteritems() + + with open(output_file, 'w') as fid: + print>>fid, '# 3D point list with one line of data per point:' + print>>fid, '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as ', + print>>fid, '(IMAGE_ID, POINT2D_IDX)' + print>>fid, '# Number of points: {},'.format(num_valid_points3D), + print>>fid, 'mean track length: unknown' + + for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx: + if point3D_idx == SceneManager.INVALID_POINT3D: + continue + + print>>fid, point3D_id, + print>>fid, array_to_string(self.points3D[point3D_idx]), + print>>fid, array_to_string(self.point3D_colors[point3D_idx]), + print>>fid, self.point3D_errors[point3D_idx], + print>>fid, array_to_string( + self.point3D_id_to_images[point3D_id].flat) + + #--------------------------------------------------------------------------- + + # return the image id associated with a given image file + def get_image_from_name(self, image_name): + image_id = self.name_to_image_id[image_name] + return image_id, self.images[image_id] + + #--------------------------------------------------------------------------- + + def get_camera(self, camera_id): + return self.cameras[camera_id] + + #--------------------------------------------------------------------------- + + def get_points3D(self, image_id, return_points2D=True, return_colors=False): + image = self.images[image_id] + + mask = (image.point3D_ids != SceneManager.INVALID_POINT3D) + + point3D_idxs = np.array([ + self.point3D_id_to_point3D_idx[point3D_id] + for point3D_id in image.point3D_ids[mask]]) + # detect filtered points + filter_mask = (point3D_idxs != SceneManager.INVALID_POINT3D) + point3D_idxs = point3D_idxs[filter_mask] + result = [self.points3D[point3D_idxs,:]] + + if return_points2D: + mask[mask] &= filter_mask + result += [image.points2D[mask]] + if return_colors: + result += [self.point3D_colors[point3D_idxs,:]] + + return result if len(result) > 1 else result[0] + + #--------------------------------------------------------------------------- + + def point3D_valid(self, point3D_id): + return (self.point3D_id_to_point3D_idx[point3D_id] != + SceneManager.INVALID_POINT3D) + + #--------------------------------------------------------------------------- + + def get_filtered_points3D(self, return_colors=False): + point3D_idxs = [ + idx for idx in self.point3D_id_to_point3D_idx.values() + if idx != SceneManager.INVALID_POINT3D] + result = [self.points3D[point3D_idxs,:]] + + if return_colors: + result += [self.point3D_colors[point3D_idxs,:]] + + return result if len(result) > 1 else result[0] + + #--------------------------------------------------------------------------- + + # return 3D points shared by two images + def get_shared_points3D(self, image_id1, image_id2): + point3D_ids = ( + set(self.images[image_id1].point3D_ids) & + set(self.images[image_id2].point3D_ids)) + point3D_ids.discard(SceneManager.INVALID_POINT3D) + + point3D_idxs = np.array([self.point3D_id_to_point3D_idx[point3D_id] + for point3D_id in point3D_ids]) + + return 
self.points3D[point3D_idxs,:] + + #--------------------------------------------------------------------------- + + # project *all* 3D points into image, return their projection coordinates, + # as well as their 3D positions + def get_viewed_points(self, image_id): + image = self.images[image_id] + + # get unfiltered points + point3D_idxs = set(self.point3D_id_to_point3D_idx.itervalues()) + point3D_idxs.discard(SceneManager.INVALID_POINT3D) + point3D_idxs = list(point3D_idxs) + points3D = self.points3D[point3D_idxs,:] + + # orient points relative to camera + R = image.q.ToR() + points3D = points3D.dot(R.T) + image.tvec[np.newaxis,:] + points3D = points3D[points3D[:,2] > 0,:] # keep points with positive z + + # put points into image coordinates + camera = self.cameras[image.camera_id] + points2D = points3D.dot(camera.get_camera_matrix().T) + points2D = points2D[:,:2] / points2D[:,2][:,np.newaxis] + + # keep points that are within the image + mask = ( + (points2D[:,0] >= 0) & + (points2D[:,1] >= 0) & + (points2D[:,0] < camera.width - 1) & + (points2D[:,1] < camera.height - 1)) + + return points2D[mask,:], points3D[mask,:] + + #--------------------------------------------------------------------------- + + def add_camera(self, camera): + self.last_camera_id += 1 + self.cameras[self.last_camera_id] = camera + return self.last_camera_id + + #--------------------------------------------------------------------------- + + def add_image(self, image): + self.last_image_id += 1 + self.images[self.last_image_id] = image + return self.last_image_id + + #--------------------------------------------------------------------------- + + def delete_images(self, image_list): + # delete specified images + for image_id in image_list: + if image_id in self.images: + del self.images[image_id] + + keep_set = set(self.images.iterkeys()) + + # delete references to specified images, and ignore any points that are + # invalidated + iter_point3D_id_to_point3D_idx = \ + self.point3D_id_to_point3D_idx.iteritems() + + for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx: + if point3D_idx == SceneManager.INVALID_POINT3D: + continue + + mask = np.array([ + image_id in keep_set + for image_id in self.point3D_id_to_images[point3D_id][:,0]]) + if np.any(mask): + self.point3D_id_to_images[point3D_id] = \ + self.point3D_id_to_images[point3D_id][mask] + else: + self.point3D_id_to_point3D_idx[point3D_id] = \ + SceneManager.INVALID_POINT3D + + #--------------------------------------------------------------------------- + + # camera_list: set of cameras whose points we'd like to keep + # min/max triangulation angle: in degrees + def filter_points3D(self, + min_track_len=0, max_error=np.inf, min_tri_angle=0, + max_tri_angle=180, image_set=set()): + + image_set = set(image_set) + + check_triangulation_angles = (min_tri_angle > 0 or max_tri_angle < 180) + if check_triangulation_angles: + max_tri_prod = np.cos(np.radians(min_tri_angle)) + min_tri_prod = np.cos(np.radians(max_tri_angle)) + + iter_point3D_id_to_point3D_idx = \ + self.point3D_id_to_point3D_idx.iteritems() + + image_ids = [] + + for point3D_id, point3D_idx in iter_point3D_id_to_point3D_idx: + if point3D_idx == SceneManager.INVALID_POINT3D: + continue + + if image_set or min_track_len > 0: + image_ids = set(self.point3D_id_to_images[point3D_id][:,0]) + + # check if error and min track length are sufficient, or if none of + # the selected cameras see the point + if (len(image_ids) < min_track_len or + self.point3D_errors[point3D_idx] > max_error or + image_set and 
image_set.isdisjoint(image_ids)): + self.point3D_id_to_point3D_idx[point3D_id] = \ + SceneManager.INVALID_POINT3D + + # find dot product between all camera viewing rays + elif check_triangulation_angles: + xyz = self.points3D[point3D_idx,:] + tvecs = np.array( + [(self.images[image_id].tvec - xyz) + for image_id in image_ids]) + tvecs /= np.linalg.norm(tvecs, axis=-1)[:,np.newaxis] + + cos_theta = np.array( + [u.dot(v) for u,v in combinations(tvecs, 2)]) + + # min_prod = cos(maximum viewing angle), and vice versa + # if maximum viewing angle is too small or too large, + # don't add this point + if (np.min(cos_theta) > max_tri_prod or + np.max(cos_theta) < min_tri_prod): + self.point3D_id_to_point3D_idx[point3D_id] = \ + SceneManager.INVALID_POINT3D + + # apply the filters to the image point3D_ids + for image in self.images.itervalues(): + mask = np.array([ + self.point3D_id_to_point3D_idx.get(point3D_id, 0) \ + == SceneManager.INVALID_POINT3D + for point3D_id in image.point3D_ids]) + image.point3D_ids[mask] = SceneManager.INVALID_POINT3D + + #--------------------------------------------------------------------------- + + # scene graph: {image_id: [image_id: #shared points]} + def build_scene_graph(self): + self.scene_graph = defaultdict(lambda: defaultdict(int)) + point3D_iter = self.point3D_id_to_images.iteritems() + + for i, (point3D_id, images) in enumerate(point3D_iter): + if not self.point3D_valid(point3D_id): + continue + + for image_id1, image_id2 in combinations(images[:,0], 2): + self.scene_graph[image_id1][image_id2] += 1 + self.scene_graph[image_id2][image_id1] += 1 diff --git a/internal/pycolmap/tools/colmap_to_nvm.py b/internal/pycolmap/tools/colmap_to_nvm.py new file mode 100644 index 0000000000000000000000000000000000000000..38a4ba8ed4b51523e1525abf57331993301adfda --- /dev/null +++ b/internal/pycolmap/tools/colmap_to_nvm.py @@ -0,0 +1,69 @@ +import itertools +import sys +sys.path.append("..") + +import numpy as np + +from pycolmap import Quaternion, SceneManager + + +#------------------------------------------------------------------------------- + +def main(args): + scene_manager = SceneManager(args.input_folder) + scene_manager.load() + + with open(args.output_file, "w") as fid: + fid.write("NVM_V3\n \n{:d}\n".format(len(scene_manager.images))) + + image_fmt_str = " {:.3f} " + 7 * "{:.7f} " + for image_id, image in scene_manager.images.iteritems(): + camera = scene_manager.cameras[image.camera_id] + f = 0.5 * (camera.fx + camera.fy) + fid.write(args.image_name_prefix + image.name) + fid.write(image_fmt_str.format( + *((f,) + tuple(image.q.q) + tuple(image.C())))) + if camera.distortion_func is None: + fid.write("0 0\n") + else: + fid.write("{:.7f} 0\n".format(-camera.k1)) + + image_id_to_idx = dict( + (image_id, i) for i, image_id in enumerate(scene_manager.images)) + + fid.write("{:d}\n".format(len(scene_manager.points3D))) + for i, point3D_id in enumerate(scene_manager.point3D_ids): + fid.write( + "{:.7f} {:.7f} {:.7f} ".format(*scene_manager.points3D[i])) + fid.write( + "{:d} {:d} {:d} ".format(*scene_manager.point3D_colors[i])) + keypoints = [ + (image_id_to_idx[image_id], kp_idx) + + tuple(scene_manager.images[image_id].points2D[kp_idx]) + for image_id, kp_idx in + scene_manager.point3D_id_to_images[point3D_id]] + fid.write("{:d}".format(len(keypoints))) + fid.write( + (len(keypoints) * " {:d} {:d} {:.3f} {:.3f}" + "\n").format( + *itertools.chain(*keypoints))) + + +#------------------------------------------------------------------------------- + +if __name__ == 
"__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Save a COLMAP reconstruction in the NVM format " + "(http://ccwu.me/vsfm/doc.html#nvm).", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("input_folder") + parser.add_argument("output_file") + + parser.add_argument("--image_name_prefix", type=str, default="", + help="prefix image names with this string (e.g., 'images/')") + + args = parser.parse_args() + + main(args) diff --git a/internal/pycolmap/tools/delete_images.py b/internal/pycolmap/tools/delete_images.py new file mode 100644 index 0000000000000000000000000000000000000000..f17a84a8a2fa842283c032eeffbce4c5a8de37db --- /dev/null +++ b/internal/pycolmap/tools/delete_images.py @@ -0,0 +1,36 @@ +import sys +sys.path.append("..") + +import numpy as np + +from pycolmap import DualQuaternion, Image, SceneManager + + +#------------------------------------------------------------------------------- + +def main(args): + scene_manager = SceneManager(args.input_folder) + scene_manager.load() + + image_ids = map(scene_manager.get_image_from_name, + iter(lambda: sys.stdin.readline().strip(), "")) + scene_manager.delete_images(image_ids) + + scene_manager.save(args.output_folder) + + +#------------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Deletes images (filenames read from stdin) from a model.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("input_folder") + parser.add_argument("output_folder") + + args = parser.parse_args() + + main(args) diff --git a/internal/pycolmap/tools/impute_missing_cameras.py b/internal/pycolmap/tools/impute_missing_cameras.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff8f322b7919f24a9b8652221dd09a0f0b16d5b --- /dev/null +++ b/internal/pycolmap/tools/impute_missing_cameras.py @@ -0,0 +1,180 @@ +import sys +sys.path.append("..") + +import numpy as np + +from pycolmap import DualQuaternion, Image, SceneManager + + +#------------------------------------------------------------------------------- + +image_to_idx = lambda im: int(im.name[:im.name.rfind(".")]) + + +#------------------------------------------------------------------------------- + +def interpolate_linear(images, camera_id, file_format): + if len(images) < 2: + raise ValueError("Need at least two images for linear interpolation!") + + prev_image = images[0] + prev_idx = image_to_idx(prev_image) + prev_dq = DualQuaternion.FromQT(prev_image.q, prev_image.t) + start = prev_idx + + new_images = [] + + for image in images[1:]: + curr_idx = image_to_idx(image) + curr_dq = DualQuaternion.FromQT(image.q, image.t) + T = curr_idx - prev_idx + Tinv = 1. / T + + # like quaternions, dq(x) = -dq(x), so we'll need to pick the one more + # appropriate for interpolation by taking -dq if the dot product of the + # two q-vectors is negative + if prev_dq.q0.dot(curr_dq.q0) < 0: + curr_dq = -curr_dq + + for i in xrange(1, T): + t = i * Tinv + dq = t * prev_dq + (1. 
- t) * curr_dq + q, t = dq.ToQT() + new_images.append( + Image(file_format.format(prev_idx + i), args.camera_id, q, t)) + + prev_idx = curr_idx + prev_dq = curr_dq + + return new_images + + +#------------------------------------------------------------------------------- + +def interpolate_hermite(images, camera_id, file_format): + if len(images) < 4: + raise ValueError( + "Need at least four images for Hermite spline interpolation!") + + new_images = [] + + # linear blending for the first frames + T0 = image_to_idx(images[0]) + dq0 = DualQuaternion.FromQT(images[0].q, images[0].t) + T1 = image_to_idx(images[1]) + dq1 = DualQuaternion.FromQT(images[1].q, images[1].t) + + if dq0.q0.dot(dq1.q0) < 0: + dq1 = -dq1 + dT = 1. / float(T1 - T0) + for j in xrange(1, T1 - T0): + t = j * dT + dq = ((1. - t) * dq0 + t * dq1).normalize() + new_images.append( + Image(file_format.format(T0 + j), camera_id, *dq.ToQT())) + + T2 = image_to_idx(images[2]) + dq2 = DualQuaternion.FromQT(images[2].q, images[2].t) + if dq1.q0.dot(dq2.q0) < 0: + dq2 = -dq2 + + # Hermite spline interpolation of dual quaternions + # pdfs.semanticscholar.org/05b1/8ede7f46c29c2722fed3376d277a1d286c55.pdf + for i in xrange(1, len(images) - 2): + T3 = image_to_idx(images[i + 2]) + dq3 = DualQuaternion.FromQT(images[i + 2].q, images[i + 2].t) + if dq2.q0.dot(dq3.q0) < 0: + dq3 = -dq3 + + prev_duration = T1 - T0 + current_duration = T2 - T1 + next_duration = T3 - T2 + + # approximate the derivatives at dq1 and dq2 using weighted central + # differences + dt1 = 1. / float(T2 - T0) + dt2 = 1. / float(T3 - T1) + + m1 = (current_duration * dt1) * (dq2 - dq1) + \ + (prev_duration * dt1) * (dq1 - dq0) + m2 = (next_duration * dt2) * (dq3 - dq2) + \ + (current_duration * dt2) * (dq2 - dq1) + + dT = 1. / float(current_duration) + + for j in xrange(1, current_duration): + t = j * dT # 0 to 1 + t2 = t * t # t squared + t3 = t2 * t # t cubed + + # coefficients of the Hermite spline (a=>dq and b=>m) + a1 = 2. * t3 - 3. * t2 + 1. + b1 = t3 - 2. * t2 + t + a2 = -2. * t3 + 3. * t2 + b2 = t3 - t2 + + dq = (a1 * dq1 + b1 * m1 + a2 * dq2 + b2 * m2).normalize() + + new_images.append( + Image(file_format.format(T1 + j), camera_id, *dq.ToQT())) + + T0, T1, T2 = T1, T2, T3 + dq0, dq1, dq2 = dq1, dq2, dq3 + + # linear blending for the last frames + dT = 1. / float(T2 - T1) + for j in xrange(1, T2 - T1): + t = j * dT # 0 to 1 + dq = ((1. 
- t) * dq1 + t * dq2).normalize() + new_images.append( + Image(file_format.format(T1 + j), camera_id, *dq.ToQT())) + + return new_images + + +#------------------------------------------------------------------------------- + +def main(args): + scene_manager = SceneManager(args.input_folder) + scene_manager.load() + + images = sorted(scene_manager.images.itervalues(), key=image_to_idx) + + if args.method.lower() == "linear": + new_images = interpolate_linear(images, args.camera_id, args.format) + else: + new_images = interpolate_hermite(images, args.camera_id, args.format) + + map(scene_manager.add_image, new_images) + + scene_manager.save(args.output_folder) + + +#------------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Given a reconstruction with ordered images *with integer " + "filenames* like '000100.png', fill in missing camera positions for " + "intermediate frames.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("input_folder") + parser.add_argument("output_folder") + + parser.add_argument("--camera_id", type=int, default=1, + help="camera id to use for the missing images") + + parser.add_argument("--format", type=str, default="{:06d}.png", + help="filename format to use for added images") + + parser.add_argument( + "--method", type=str.lower, choices=("linear", "hermite"), + default="hermite", + help="Pose imputation method") + + args = parser.parse_args() + + main(args) diff --git a/internal/pycolmap/tools/save_cameras_as_ply.py b/internal/pycolmap/tools/save_cameras_as_ply.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec89506f61e8fbc12e026853c1a1d663a21d658 --- /dev/null +++ b/internal/pycolmap/tools/save_cameras_as_ply.py @@ -0,0 +1,92 @@ +import sys +sys.path.append("..") + +import numpy as np +import os + +from pycolmap import SceneManager + + +#------------------------------------------------------------------------------- + +# Saves the cameras as a mesh +# +# inputs: +# - ply_file: output file +# - images: ordered array of pycolmap Image objects +# - color: color string for the camera +# - scale: amount to shrink/grow the camera model +def save_camera_ply(ply_file, images, scale): + points3D = scale * np.array(( + (0., 0., 0.), + (-1., -1., 1.), + (-1., 1., 1.), + (1., -1., 1.), + (1., 1., 1.))) + + faces = np.array(((0, 2, 1), + (0, 4, 2), + (0, 3, 4), + (0, 1, 3), + (1, 2, 4), + (1, 4, 3))) + + r = np.linspace(0, 255, len(images), dtype=np.uint8) + g = 255 - r + b = r - np.linspace(0, 128, len(images), dtype=np.uint8) + color = np.column_stack((r, g, b)) + + with open(ply_file, "w") as fid: + print>>fid, "ply" + print>>fid, "format ascii 1.0" + print>>fid, "element vertex", len(points3D) * len(images) + print>>fid, "property float x" + print>>fid, "property float y" + print>>fid, "property float z" + print>>fid, "property uchar red" + print>>fid, "property uchar green" + print>>fid, "property uchar blue" + print>>fid, "element face", len(faces) * len(images) + print>>fid, "property list uchar int vertex_index" + print>>fid, "end_header" + + for image, c in zip(images, color): + for p3D in (points3D.dot(image.R()) + image.C()): + print>>fid, p3D[0], p3D[1], p3D[2], c[0], c[1], c[2] + + for i in xrange(len(images)): + for f in (faces + len(points3D) * i): + print>>fid, "3 {} {} {}".format(*f) + + +#------------------------------------------------------------------------------- + +def 
main(args): + scene_manager = SceneManager(args.input_folder) + scene_manager.load_images() + + images = sorted(scene_manager.images.itervalues(), + key=lambda image: image.name) + + save_camera_ply(args.output_file, images, args.scale) + + +#------------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Saves camera positions to a PLY for easy viewing outside " + "of COLMAP. Currently, camera FoV is not reflected in the output.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("input_folder") + parser.add_argument("output_file") + + parser.add_argument("--scale", type=float, default=1., + help="Scaling factor for the camera mesh.") + + args = parser.parse_args() + + main(args) diff --git a/internal/pycolmap/tools/transform_model.py b/internal/pycolmap/tools/transform_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b22f1b3699a8ec029715324548f8dc823a4f59 --- /dev/null +++ b/internal/pycolmap/tools/transform_model.py @@ -0,0 +1,48 @@ +import sys +sys.path.append("..") + +import numpy as np + +from pycolmap import Quaternion, SceneManager + + +#------------------------------------------------------------------------------- + +def main(args): + scene_manager = SceneManager(args.input_folder) + scene_manager.load() + + # expect each line of input corresponds to one row + P = np.array([ + map(float, sys.stdin.readline().strip().split()) for _ in xrange(3)]) + + scene_manager.points3D[:] = scene_manager.points3D.dot(P[:,:3].T) + P[:,3] + + # get rotation without any global scaling (assuming isotropic scaling) + scale = np.cbrt(np.linalg.det(P[:,:3])) + q_old_from_new = ~Quaternion.FromR(P[:,:3] / scale) + + for image in scene_manager.images.itervalues(): + image.q *= q_old_from_new + image.tvec = scale * image.tvec - image.R().dot(P[:,3]) + + scene_manager.save(args.output_folder) + + +#------------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Apply a 3x4 transformation matrix to a COLMAP model and " + "save the result as a new model. 
Row-major input can be piped in from " + "a file or entered via the command line.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("input_folder") + parser.add_argument("output_folder") + + args = parser.parse_args() + + main(args) diff --git a/internal/pycolmap/tools/write_camera_track_to_bundler.py b/internal/pycolmap/tools/write_camera_track_to_bundler.py new file mode 100644 index 0000000000000000000000000000000000000000..66fc91ab4ee84eb01e3445fc407c58b9acda5be7 --- /dev/null +++ b/internal/pycolmap/tools/write_camera_track_to_bundler.py @@ -0,0 +1,60 @@ +import sys +sys.path.append("..") + +import numpy as np + +from pycolmap import SceneManager + + +#------------------------------------------------------------------------------- + +def main(args): + scene_manager = SceneManager(args.input_folder) + scene_manager.load_cameras() + scene_manager.load_images() + + if args.sort: + images = sorted( + scene_manager.images.itervalues(), key=lambda im: im.name) + else: + images = scene_manager.images.values() + + fid = open(args.output_file, "w") + fid_filenames = open(args.output_file + ".list.txt", "w") + + print>>fid, "# Bundle file v0.3" + print>>fid, len(images), 0 + + for image in images: + print>>fid_filenames, image.name + camera = scene_manager.cameras[image.camera_id] + print>>fid, 0.5 * (camera.fx + camera.fy), 0, 0 + R, t = image.R(), image.t + print>>fid, R[0, 0], R[0, 1], R[0, 2] + print>>fid, -R[1, 0], -R[1, 1], -R[1, 2] + print>>fid, -R[2, 0], -R[2, 1], -R[2, 2] + print>>fid, t[0], -t[1], -t[2] + + fid.close() + fid_filenames.close() + + +#------------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Saves the camera positions in the Bundler format. 
Note " + "that 3D points are not saved.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("input_folder") + parser.add_argument("output_file") + + parser.add_argument("--sort", default=False, action="store_true", + help="sort the images by their filename") + + args = parser.parse_args() + + main(args) diff --git a/internal/pycolmap/tools/write_depthmap_to_ply.py b/internal/pycolmap/tools/write_depthmap_to_ply.py new file mode 100644 index 0000000000000000000000000000000000000000..967eef0464aa60444076b8fcedae4a378943b126 --- /dev/null +++ b/internal/pycolmap/tools/write_depthmap_to_ply.py @@ -0,0 +1,139 @@ +import sys +sys.path.append("..") + +import imageio +import numpy as np +import os + +from plyfile import PlyData, PlyElement +from pycolmap import SceneManager +from scipy.ndimage.interpolation import zoom + + +#------------------------------------------------------------------------------- + +def main(args): + suffix = ".photometric.bin" if args.photometric else ".geometric.bin" + + image_file = os.path.join(args.dense_folder, "images", args.image_filename) + depth_file = os.path.join( + args.dense_folder, args.stereo_folder, "depth_maps", + args.image_filename + suffix) + if args.save_normals: + normals_file = os.path.join( + args.dense_folder, args.stereo_folder, "normal_maps", + args.image_filename + suffix) + + # load camera intrinsics from the COLMAP reconstruction + scene_manager = SceneManager(os.path.join(args.dense_folder, "sparse")) + scene_manager.load_cameras() + scene_manager.load_images() + + image_id, image = scene_manager.get_image_from_name(args.image_filename) + camera = scene_manager.cameras[image.camera_id] + rotation_camera_from_world = image.R() + camera_center = image.C() + + # load image, depth map, and normal map + image = imageio.imread(image_file) + + with open(depth_file, "rb") as fid: + w = int("".join(iter(lambda: fid.read(1), "&"))) + h = int("".join(iter(lambda: fid.read(1), "&"))) + c = int("".join(iter(lambda: fid.read(1), "&"))) + depth_map = np.fromfile(fid, np.float32).reshape(h, w) + if (h, w) != image.shape[:2]: + depth_map = zoom( + depth_map, + (float(image.shape[0]) / h, float(image.shape[1]) / w), + order=0) + + if args.save_normals: + with open(normals_file, "rb") as fid: + w = int("".join(iter(lambda: fid.read(1), "&"))) + h = int("".join(iter(lambda: fid.read(1), "&"))) + c = int("".join(iter(lambda: fid.read(1), "&"))) + normals = np.fromfile( + fid, np.float32).reshape(c, h, w).transpose([1, 2, 0]) + if (h, w) != image.shape[:2]: + normals = zoom( + normals, + (float(image.shape[0]) / h, float(image.shape[1]) / w, 1.), + order=0) + + if args.min_depth is not None: + depth_map[depth_map < args.min_depth] = 0. + if args.max_depth is not None: + depth_map[depth_map > args.max_depth] = 0. + + # create 3D points + #depth_map = np.minimum(depth_map, 100.) 
+ points3D = np.dstack(camera.get_image_grid() + [depth_map]) + points3D[:,:,:2] *= depth_map[:,:,np.newaxis] + + # save + points3D = points3D.astype(np.float32).reshape(-1, 3) + if args.save_normals: + normals = normals.astype(np.float32).reshape(-1, 3) + image = image.reshape(-1, 3) + if image.dtype != np.uint8: + if image.max() <= 1: + image = (image * 255.).astype(np.uint8) + else: + image = image.astype(np.uint8) + + if args.world_space: + points3D = points3D.dot(rotation_camera_from_world) + camera_center + if args.save_normals: + normals = normals.dot(rotation_camera_from_world) + + if args.save_normals: + vertices = np.rec.fromarrays( + tuple(points3D.T) + tuple(normals.T) + tuple(image.T), + names="x,y,z,nx,ny,nz,red,green,blue") + else: + vertices = np.rec.fromarrays( + tuple(points3D.T) + tuple(image.T), names="x,y,z,red,green,blue") + vertices = PlyElement.describe(vertices, "vertex") + PlyData([vertices]).write(args.output_filename) + + +#------------------------------------------------------------------------------- + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument("dense_folder", type=str) + parser.add_argument("image_filename", type=str) + parser.add_argument("output_filename", type=str) + + parser.add_argument( + "--photometric", default=False, action="store_true", + help="use photometric depthmap instead of geometric") + + parser.add_argument( + "--world_space", default=False, action="store_true", + help="apply the camera->world extrinsic transformation to the result") + + parser.add_argument( + "--save_normals", default=False, action="store_true", + help="load the estimated normal map and save as part of the PLY") + + parser.add_argument( + "--stereo_folder", type=str, default="stereo", + help="folder in the dense workspace containing depth and normal maps") + + parser.add_argument( + "--min_depth", type=float, default=None, + help="set pixels with depth less than this value to zero depth") + + parser.add_argument( + "--max_depth", type=float, default=None, + help="set pixels with depth greater than this value to zero depth") + + args = parser.parse_args() + + main(args) diff --git a/internal/raw_utils.py b/internal/raw_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4406cd118f917f3fa12fca91dd95e62bd0bcc0fb --- /dev/null +++ b/internal/raw_utils.py @@ -0,0 +1,360 @@ +import glob +import json +import os +from internal import image as lib_image +from internal import math +from internal import utils +import numpy as np +import rawpy + + +def postprocess_raw(raw, camtorgb, exposure=None): + """Converts demosaicked raw to sRGB with a minimal postprocessing pipeline. + + Args: + raw: [H, W, 3], demosaicked raw camera image. + camtorgb: [3, 3], color correction transformation to apply to raw image. + exposure: color value to be scaled to pure white after color correction. + If None, "autoexposes" at the 97th percentile. + + Returns: + srgb: [H, W, 3], color corrected + exposed + gamma mapped image. + """ + if raw.shape[-1] != 3: + raise ValueError(f'raw.shape[-1] is {raw.shape[-1]}, expected 3') + if camtorgb.shape != (3, 3): + raise ValueError(f'camtorgb.shape is {camtorgb.shape}, expected (3, 3)') + # Convert from camera color space to standard linear RGB color space. 
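+ # raw has shape [H, W, 3], so the matmul with camtorgb.T applies the color
+ # matrix to every pixel: rgb_linear[i, j] = camtorgb @ raw[i, j].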
+ rgb_linear = np.matmul(raw, camtorgb.T) + if exposure is None: + exposure = np.percentile(rgb_linear, 97) + # "Expose" image by mapping the input exposure level to white and clipping. + rgb_linear_scaled = np.clip(rgb_linear / exposure, 0, 1) + # Apply sRGB gamma curve to serve as a simple tonemap. + srgb = lib_image.linear_to_srgb_np(rgb_linear_scaled) + return srgb + + +def pixels_to_bayer_mask(pix_x, pix_y): + """Computes binary RGB Bayer mask values from integer pixel coordinates.""" + # Red is top left (0, 0). + r = (pix_x % 2 == 0) * (pix_y % 2 == 0) + # Green is top right (0, 1) and bottom left (1, 0). + g = (pix_x % 2 == 1) * (pix_y % 2 == 0) + (pix_x % 2 == 0) * (pix_y % 2 == 1) + # Blue is bottom right (1, 1). + b = (pix_x % 2 == 1) * (pix_y % 2 == 1) + return np.stack([r, g, b], -1).astype(np.float32) + + +def bilinear_demosaic(bayer): + """Converts Bayer data into a full RGB image using bilinear demosaicking. + + Input data should be ndarray of shape [height, width] with 2x2 mosaic pattern: + ------------- + |red |green| + ------------- + |green|blue | + ------------- + Red and blue channels are bilinearly upsampled 2x, missing green channel + elements are the average of the neighboring 4 values in a cross pattern. + + Args: + bayer: [H, W] array, Bayer mosaic pattern input image. + + Returns: + rgb: [H, W, 3] array, full RGB image. + """ + + def reshape_quads(*planes): + """Reshape pixels from four input images to make tiled 2x2 quads.""" + planes = np.stack(planes, -1) + shape = planes.shape[:-1] + # Create [2, 2] arrays out of 4 channels. + zup = planes.reshape(shape + (2, 2,)) + # Transpose so that x-axis dimensions come before y-axis dimensions. + zup = np.transpose(zup, (0, 2, 1, 3)) + # Reshape to 2D. + zup = zup.reshape((shape[0] * 2, shape[1] * 2)) + return zup + + def bilinear_upsample(z): + """2x bilinear image upsample.""" + # Using np.roll makes the right and bottom edges wrap around. The raw image + # data has a few garbage columns/rows at the edges that must be discarded + # anyway, so this does not matter in practice. + # Horizontally interpolated values. + zx = .5 * (z + np.roll(z, -1, axis=-1)) + # Vertically interpolated values. + zy = .5 * (z + np.roll(z, -1, axis=-2)) + # Diagonally interpolated values. + zxy = .5 * (zx + np.roll(zx, -1, axis=-2)) + return reshape_quads(z, zx, zy, zxy) + + def upsample_green(g1, g2): + """Special 2x upsample from the two green channels.""" + z = np.zeros_like(g1) + z = reshape_quads(z, g1, g2, z) + alt = 0 + # Grab the 4 directly adjacent neighbors in a "cross" pattern. + for i in range(4): + axis = -1 - (i // 2) + roll = -1 + 2 * (i % 2) + alt = alt + .25 * np.roll(z, roll, axis=axis) + # For observed pixels, alt = 0, and for unobserved pixels, alt = avg(cross), + # so alt + z will have every pixel filled in. + return alt + z + + r, g1, g2, b = [bayer[(i // 2)::2, (i % 2)::2] for i in range(4)] + r = bilinear_upsample(r) + # Flip in x and y before and after calling upsample, as bilinear_upsample + # assumes that the samples are at the top-left corner of the 2x2 sample. + b = bilinear_upsample(b[::-1, ::-1])[::-1, ::-1] + g = upsample_green(g1, g2) + rgb = np.stack([r, g, b], -1) + return rgb + + +def load_raw_images(image_dir, image_names=None): + """Loads raw images and their metadata from disk. + + Args: + image_dir: directory containing raw image and EXIF data. + image_names: files to load (ignores file extension), loads all DNGs if None. + + Returns: + A tuple (images, exifs). 
+ images: [N, height, width, 3] array of raw sensor data. + exifs: [N] list of dicts, one per image, containing the EXIF data. + Raises: + ValueError: The requested `image_dir` does not exist on disk. + """ + + if not utils.file_exists(image_dir): + raise ValueError(f'Raw image folder {image_dir} does not exist.') + + # Load raw images (dng files) and exif metadata (json files). + def load_raw_exif(image_name): + base = os.path.join(image_dir, os.path.splitext(image_name)[0]) + with utils.open_file(base + '.dng', 'rb') as f: + raw = rawpy.imread(f).raw_image + with utils.open_file(base + '.json', 'rb') as f: + exif = json.load(f)[0] + return raw, exif + + if image_names is None: + image_names = [ + os.path.basename(f) + for f in sorted(glob.glob(os.path.join(image_dir, '*.dng'))) + ] + + data = [load_raw_exif(x) for x in image_names] + raws, exifs = zip(*data) + raws = np.stack(raws, axis=0).astype(np.float32) + + return raws, exifs + + +# Brightness percentiles to use for re-exposing and tonemapping raw images. +_PERCENTILE_LIST = (80, 90, 97, 99, 100) + +# Relevant fields to extract from raw image EXIF metadata. +# For details regarding EXIF parameters, see: +# https://www.adobe.com/content/dam/acom/en/products/photoshop/pdfs/dng_spec_1.4.0.0.pdf. +_EXIF_KEYS = ( + 'BlackLevel', # Black level offset added to sensor measurements. + 'WhiteLevel', # Maximum possible sensor measurement. + 'AsShotNeutral', # RGB white balance coefficients. + 'ColorMatrix2', # XYZ to camera color space conversion matrix. + 'NoiseProfile', # Shot and read noise levels. +) + +# Color conversion from reference illuminant XYZ to RGB color space. +# See http://www.brucelindbloom.com/index.html?Eqn_RGB_XYZ_Matrix.html. +_RGB2XYZ = np.array([[0.4124564, 0.3575761, 0.1804375], + [0.2126729, 0.7151522, 0.0721750], + [0.0193339, 0.1191920, 0.9503041]]) + + +def process_exif(exifs): + """Processes list of raw image EXIF data into useful metadata dict. + + Input should be a list of dictionaries loaded from JSON files. + These JSON files are produced by running + $ exiftool -json IMAGE.dng > IMAGE.json + for each input raw file. + + We extract only the parameters relevant to + 1. Rescaling the raw data to [0, 1], + 2. White balance and color correction, and + 3. Noise level estimation. + + Args: + exifs: a list of dicts containing EXIF data as loaded from JSON files. + + Returns: + meta: a dict of the relevant metadata for running RawNeRF. + """ + meta = {} + exif = exifs[0] + # Convert from array of dicts (exifs) to dict of arrays (meta). + for key in _EXIF_KEYS: + exif_value = exif.get(key) + if exif_value is None: + continue + # Values can be a single int or float... + if isinstance(exif_value, int) or isinstance(exif_value, float): + vals = [x[key] for x in exifs] + # Or a string of numbers with ' ' between. + elif isinstance(exif_value, str): + vals = [[float(z) for z in x[key].split(' ')] for x in exifs] + meta[key] = np.squeeze(np.array(vals)) + # Shutter speed is a special case, a string written like 1/N. + meta['ShutterSpeed'] = np.fromiter( + (1. / float(exif['ShutterSpeed'].split('/')[1]) for exif in exifs), float) + + # Create raw-to-sRGB color transform matrices. Pipeline is: + # cam space -> white balanced cam space ("camwb") -> XYZ space -> RGB space. + # 'AsShotNeutral' is an RGB triplet representing how pure white would measure + # on the sensor, so dividing by these numbers corrects the white balance. + whitebalance = meta['AsShotNeutral'].reshape(-1, 3) + cam2camwb = np.array([np.diag(1. 
/ x) for x in whitebalance]) + # ColorMatrix2 converts from XYZ color space to "reference illuminant" (white + # balanced) camera space. + xyz2camwb = meta['ColorMatrix2'].reshape(-1, 3, 3) + rgb2camwb = xyz2camwb @ _RGB2XYZ + # We normalize the rows of the full color correction matrix, as is done in + # https://github.com/AbdoKamel/simple-camera-pipeline. + rgb2camwb /= rgb2camwb.sum(axis=-1, keepdims=True) + # Combining color correction with white balance gives the entire transform. + cam2rgb = np.linalg.inv(rgb2camwb) @ cam2camwb + meta['cam2rgb'] = cam2rgb + + return meta + + +def load_raw_dataset(split, data_dir, image_names, exposure_percentile, n_downsample): + """Loads and processes a set of RawNeRF input images. + + Includes logic necessary for special "test" scenes that include a noiseless + ground truth frame, produced by HDR+ merge. + + Args: + split: DataSplit.TRAIN or DataSplit.TEST, only used for test scene logic. + data_dir: base directory for scene data. + image_names: which images were successfully posed by COLMAP. + exposure_percentile: what brightness percentile to expose to white. + n_downsample: returned images are downsampled by a factor of n_downsample. + + Returns: + A tuple (images, meta, testscene). + images: [N, height // n_downsample, width // n_downsample, 3] array of + demosaicked raw image data. + meta: EXIF metadata and other useful processing parameters. Includes per + image exposure information that can be passed into the NeRF model with + each ray: the set of unique exposure times is determined and each image + assigned a corresponding exposure index (mapping to an exposure value). + These are keys 'unique_shutters', 'exposure_idx', and 'exposure_value' in + the `meta` dictionary. + We rescale so the maximum `exposure_value` is 1 for convenience. + testscene: True when dataset includes ground truth test image, else False. + """ + + image_dir = os.path.join(data_dir, 'raw') + + testimg_file = os.path.join(data_dir, 'hdrplus_test/merged.dng') + testscene = utils.file_exists(testimg_file) + if testscene: + # Test scenes have train/ and test/ split subdirectories inside raw/. + image_dir = os.path.join(image_dir, split.value) + if split == utils.DataSplit.TEST: + # COLMAP image names not valid for test split of test scene. + image_names = None + else: + # Discard the first COLMAP image name as it is a copy of the test image. + image_names = image_names[1:] + + raws, exifs = load_raw_images(image_dir, image_names) + meta = process_exif(exifs) + + if testscene and split == utils.DataSplit.TEST: + # Test split for test scene must load the "ground truth" HDR+ merged image. + with utils.open_file(testimg_file, 'rb') as imgin: + testraw = rawpy.imread(imgin).raw_image + # HDR+ output has 2 extra bits of fixed precision, need to divide by 4. + testraw = testraw.astype(np.float32) / 4. + # Need to rescale long exposure test image by fast:slow shutter speed ratio. + fast_shutter = meta['ShutterSpeed'][0] + slow_shutter = meta['ShutterSpeed'][-1] + shutter_ratio = fast_shutter / slow_shutter + # Replace loaded raws with the "ground truth" test image. + raws = testraw[None] + # Test image shares metadata with the first loaded image (fast exposure). + meta = {k: meta[k][:1] for k in meta} + else: + shutter_ratio = 1. + + # Next we determine an index for each unique shutter speed in the data. + shutter_speeds = meta['ShutterSpeed'] + # Sort the shutter speeds from slowest (largest) to fastest (smallest). + # This way index 0 will always correspond to the brightest image. 
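As a toy worked example of the shutter-speed bookkeeping implemented just below (made-up shutter speeds, in seconds): index 0 is assigned to the brightest, i.e. longest, exposure, and exposure values are expressed relative to it.

import numpy as np

shutter_speeds = np.array([1 / 30, 1 / 120, 1 / 30, 1 / 480])
unique_shutters = np.sort(np.unique(shutter_speeds))[::-1]   # [1/30, 1/120, 1/480]
exposure_idx = np.array([int(np.where(unique_shutters == s)[0][0])
                         for s in shutter_speeds])           # [0, 1, 0, 2]
exposure_values = shutter_speeds / unique_shutters[0]        # [1.0, 0.25, 1.0, 0.0625]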
+ unique_shutters = np.sort(np.unique(shutter_speeds))[::-1] + exposure_idx = np.zeros_like(shutter_speeds, dtype=np.int32) + for i, shutter in enumerate(unique_shutters): + # Assign index `i` to all images with shutter speed `shutter`. + exposure_idx[shutter_speeds == shutter] = i + meta['exposure_idx'] = exposure_idx + meta['unique_shutters'] = unique_shutters + # Rescale to use relative shutter speeds, where 1. is the brightest. + # This way the NeRF output with exposure=1 will always be reasonable. + meta['exposure_values'] = shutter_speeds / unique_shutters[0] + + # Rescale raw sensor measurements to [0, 1] (plus noise). + blacklevel = meta['BlackLevel'].reshape(-1, 1, 1) + whitelevel = meta['WhiteLevel'].reshape(-1, 1, 1) + images = (raws - blacklevel) / (whitelevel - blacklevel) * shutter_ratio + + # Calculate value for exposure level when gamma mapping, defaults to 97%. + # Always based on full resolution image 0 (for consistency). + image0_raw_demosaic = np.array(bilinear_demosaic(images[0])) + image0_rgb = image0_raw_demosaic @ meta['cam2rgb'][0].T + exposure = np.percentile(image0_rgb, exposure_percentile) + meta['exposure'] = exposure + # Sweep over various exposure percentiles to visualize in training logs. + exposure_levels = {p: np.percentile(image0_rgb, p) for p in _PERCENTILE_LIST} + meta['exposure_levels'] = exposure_levels + + # Create postprocessing function mapping raw images to tonemapped sRGB space. + cam2rgb0 = meta['cam2rgb'][0] + meta['postprocess_fn'] = lambda z, x=exposure: postprocess_raw(z, cam2rgb0, x) + + def processing_fn(x): + x_ = np.array(x) + x_demosaic = bilinear_demosaic(x_) + if n_downsample > 1: + x_demosaic = lib_image.downsample(x_demosaic, n_downsample) + return np.array(x_demosaic) + + images = np.stack([processing_fn(im) for im in images], axis=0) + + return images, meta, testscene + + +def best_fit_affine(x, y, axis): + """Computes best fit a, b such that a * x + b = y, in a least square sense.""" + x_m = x.mean(axis=axis) + y_m = y.mean(axis=axis) + xy_m = (x * y).mean(axis=axis) + xx_m = (x * x).mean(axis=axis) + # slope a = Cov(x, y) / Cov(x, x). + a = (xy_m - x_m * y_m) / (xx_m - x_m * x_m) + b = y_m - a * x_m + return a, b + + +def match_images_affine(est, gt, axis=(0, 1)): + """Computes affine best fit of gt->est, then maps est back to match gt.""" + # Mapping is computed gt->est to be robust since `est` may be very noisy. + a, b = best_fit_affine(gt, est, axis=axis) + # Inverse mapping back to gt ensures we use a consistent space for metrics. + est_matched = (est - b) / a + return est_matched diff --git a/internal/ref_utils.py b/internal/ref_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..944977849b9a5840d10841eda375c81e5975bc08 --- /dev/null +++ b/internal/ref_utils.py @@ -0,0 +1,174 @@ +from internal import math +import torch +import numpy as np + + +def reflect(viewdirs, normals): + """Reflect view directions about normals. + + The reflection of a vector v about a unit vector n is a vector u such that + dot(v, n) = dot(u, n), and dot(u, u) = dot(v, v). The solution to these two + equations is u = 2 dot(n, v) n - v. + + Args: + viewdirs: [..., 3] array of view directions. + normals: [..., 3] array of normal directions (assumed to be unit vectors). + + Returns: + [..., 3] array of reflection directions. 
+ """ + return 2.0 * torch.sum(normals * viewdirs, dim=-1, keepdim=True) * normals - viewdirs + + +def l2_normalize(x): + """Normalize x to unit length along last axis.""" + return torch.nn.functional.normalize(x, dim=-1, eps=torch.finfo(x.dtype).eps) + + +def l2_normalize_np(x): + """Normalize x to unit length along last axis.""" + return x / np.sqrt(np.maximum(np.sum(x ** 2, axis=-1, keepdims=True), np.finfo(x.dtype).eps)) + + +def compute_weighted_mae(weights, normals, normals_gt): + """Compute weighted mean angular error, assuming normals are unit length.""" + one_eps = 1 - torch.finfo(weights.dtype).eps + return (weights * torch.arccos(torch.clip((normals * normals_gt).sum(-1), + -one_eps, one_eps))).sum() / weights.sum() * 180.0 / torch.pi + + +def compute_weighted_mae_np(weights, normals, normals_gt): + """Compute weighted mean angular error, assuming normals are unit length.""" + one_eps = 1 - np.finfo(weights.dtype).eps + return (weights * np.arccos(np.clip((normals * normals_gt).sum(-1), + -one_eps, one_eps))).sum() / weights.sum() * 180.0 / np.pi + + +def generalized_binomial_coeff(a, k): + """Compute generalized binomial coefficients.""" + return np.prod(a - np.arange(k)) / np.math.factorial(k) + + +def assoc_legendre_coeff(l, m, k): + """Compute associated Legendre polynomial coefficients. + + Returns the coefficient of the cos^k(theta)*sin^m(theta) term in the + (l, m)th associated Legendre polynomial, P_l^m(cos(theta)). + + Args: + l: associated Legendre polynomial degree. + m: associated Legendre polynomial order. + k: power of cos(theta). + + Returns: + A float, the coefficient of the term corresponding to the inputs. + """ + return ((-1) ** m * 2 ** l * np.math.factorial(l) / np.math.factorial(k) / + np.math.factorial(l - k - m) * + generalized_binomial_coeff(0.5 * (l + k + m - 1.0), l)) + + +def sph_harm_coeff(l, m, k): + """Compute spherical harmonic coefficients.""" + return (np.sqrt( + (2.0 * l + 1.0) * np.math.factorial(l - m) / + (4.0 * np.pi * np.math.factorial(l + m))) * assoc_legendre_coeff(l, m, k)) + + +def get_ml_array(deg_view): + """Create a list with all pairs of (l, m) values to use in the encoding.""" + ml_list = [] + for i in range(deg_view): + l = 2 ** i + # Only use nonnegative m values, later splitting real and imaginary parts. + for m in range(l + 1): + ml_list.append((m, l)) + + # Convert list into a numpy array. + ml_array = np.array(ml_list).T + return ml_array + + +def generate_ide_fn(deg_view): + """Generate integrated directional encoding (IDE) function. + + This function returns a function that computes the integrated directional + encoding from Equations 6-8 of arxiv.org/abs/2112.03907. + + Args: + deg_view: number of spherical harmonics degrees to use. + + Returns: + A function for evaluating integrated directional encoding. + + Raises: + ValueError: if deg_view is larger than 5. + """ + if deg_view > 5: + raise ValueError('Only deg_view of at most 5 is numerically stable.') + + ml_array = get_ml_array(deg_view) + l_max = 2 ** (deg_view - 1) + + # Create a matrix corresponding to ml_array holding all coefficients, which, + # when multiplied (from the right) by the z coordinate Vandermonde matrix, + # results in the z component of the encoding. 
+ mat = np.zeros((l_max + 1, ml_array.shape[1])) + for i, (m, l) in enumerate(ml_array.T): + for k in range(l - m + 1): + mat[k, i] = sph_harm_coeff(l, m, k) + mat = torch.from_numpy(mat).float() + ml_array = torch.from_numpy(ml_array).float() + + def integrated_dir_enc_fn(xyz, kappa_inv): + """Function returning integrated directional encoding (IDE). + + Args: + xyz: [..., 3] array of Cartesian coordinates of directions to evaluate at. + kappa_inv: [..., 1] reciprocal of the concentration parameter of the von + Mises-Fisher distribution. + + Returns: + An array with the resulting IDE. + """ + x = xyz[..., 0:1] + y = xyz[..., 1:2] + z = xyz[..., 2:3] + + # Compute z Vandermonde matrix. + vmz = torch.cat([z ** i for i in range(mat.shape[0])], dim=-1) + + # Compute x+iy Vandermonde matrix. + vmxy = torch.cat([(x + 1j * y) ** m for m in ml_array[0, :]], dim=-1) + + # Get spherical harmonics. + sph_harms = vmxy * math.matmul(vmz, mat.to(xyz.device)) + + # Apply attenuation function using the von Mises-Fisher distribution + # concentration parameter, kappa. + sigma = 0.5 * ml_array[1, :] * (ml_array[1, :] + 1) + sigma = sigma.to(sph_harms.device) + ide = sph_harms * torch.exp(-sigma * kappa_inv) + + # Split into real and imaginary parts and return + return torch.cat([torch.real(ide), torch.imag(ide)], dim=-1) + + return integrated_dir_enc_fn + + +def generate_dir_enc_fn(deg_view): + """Generate directional encoding (DE) function. + + Args: + deg_view: number of spherical harmonics degrees to use. + + Returns: + A function for evaluating directional encoding. + """ + integrated_dir_enc_fn = generate_ide_fn(deg_view) + + def dir_enc_fn(xyz): + """Function returning directional encoding (DE).""" + return integrated_dir_enc_fn(xyz, torch.zeros_like(xyz[..., :1])) + + return dir_enc_fn diff --git a/internal/render.py b/internal/render.py new file mode 100644 index 0000000000000000000000000000000000000000..e158db7d77bcad15ebe01ff8093e1e1135857164 --- /dev/null +++ b/internal/render.py @@ -0,0 +1,242 @@ +import os.path + +from internal import stepfun +from internal import math +from internal import utils +import torch +import torch.nn.functional as F + + +def lift_gaussian(d, t_mean, t_var, r_var, diag): + """Lift a Gaussian defined along a ray to 3D coordinates.""" + mean = d[..., None, :] * t_mean[..., None] + eps = torch.finfo(d.dtype).eps + # eps = 1e-3 + d_mag_sq = torch.sum(d ** 2, dim=-1, keepdim=True).clamp_min(eps) + + if diag: + d_outer_diag = d ** 2 + null_outer_diag = 1 - d_outer_diag / d_mag_sq + t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :] + xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :] + cov_diag = t_cov_diag + xy_cov_diag + return mean, cov_diag + else: + d_outer = d[..., :, None] * d[..., None, :] + eye = torch.eye(d.shape[-1], device=d.device) + null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :] + t_cov = t_var[..., None, None] * d_outer[..., None, :, :] + xy_cov = r_var[..., None, None] * null_outer[..., None, :, :] + cov = t_cov + xy_cov + return mean, cov + + +def conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True): + """Approximate a conical frustum as a Gaussian distribution (mean+cov). + + Assumes the ray is originating from the origin, and base_radius is the + radius at dist=1. Doesn't assume `d` is normalized. + + Args: + d: the axis of the cone + t0: the starting distance of the frustum. + t1: the ending distance of the frustum. + base_radius: the scale of the radius as a function of distance. 
+ diag: whether or the Gaussian will be diagonal or full-covariance. + stable: whether or not to use the stable computation described in + the paper (setting this to False will cause catastrophic failure). + + Returns: + a Gaussian (mean and covariance). + """ + if stable: + # Equation 7 in the paper (https://arxiv.org/abs/2103.13415). + mu = (t0 + t1) / 2 # The average of the two `t` values. + hw = (t1 - t0) / 2 # The half-width of the two `t` values. + eps = torch.finfo(d.dtype).eps + # eps = 1e-3 + t_mean = mu + (2 * mu * hw ** 2) / (3 * mu ** 2 + hw ** 2).clamp_min(eps) + denom = (3 * mu ** 2 + hw ** 2).clamp_min(eps) + t_var = (hw ** 2) / 3 - (4 / 15) * hw ** 4 * (12 * mu ** 2 - hw ** 2) / denom ** 2 + r_var = (mu ** 2) / 4 + (5 / 12) * hw ** 2 - (4 / 15) * (hw ** 4) / denom + else: + # Equations 37-39 in the paper. + t_mean = (3 * (t1 ** 4 - t0 ** 4)) / (4 * (t1 ** 3 - t0 ** 3)) + r_var = 3 / 20 * (t1 ** 5 - t0 ** 5) / (t1 ** 3 - t0 ** 3) + t_mosq = 3 / 5 * (t1 ** 5 - t0 ** 5) / (t1 ** 3 - t0 ** 3) + t_var = t_mosq - t_mean ** 2 + r_var *= base_radius ** 2 + return lift_gaussian(d, t_mean, t_var, r_var, diag) + + +def cylinder_to_gaussian(d, t0, t1, radius, diag): + """Approximate a cylinder as a Gaussian distribution (mean+cov). + + Assumes the ray is originating from the origin, and radius is the + radius. Does not renormalize `d`. + + Args: + d: the axis of the cylinder + t0: the starting distance of the cylinder. + t1: the ending distance of the cylinder. + radius: the radius of the cylinder + diag: whether or the Gaussian will be diagonal or full-covariance. + + Returns: + a Gaussian (mean and covariance). + """ + t_mean = (t0 + t1) / 2 + r_var = radius ** 2 / 4 + t_var = (t1 - t0) ** 2 / 12 + return lift_gaussian(d, t_mean, t_var, r_var, diag) + + +def cast_rays(tdist, origins, directions, cam_dirs, radii, rand=True, n=7, m=3, std_scale=0.5, **kwargs): + """Cast rays (cone- or cylinder-shaped) and featurize sections of it. + + Args: + tdist: float array, the "fencepost" distances along the ray. + origins: float array, the ray origin coordinates. + directions: float array, the ray direction vectors. + radii: float array, the radii (base radii for cones) of the rays. + ray_shape: string, the shape of the ray, must be 'cone' or 'cylinder'. + diag: boolean, whether or not the covariance matrices should be diagonal. + + Returns: + a tuple of arrays of means and covariances. 
+ """ + t0 = tdist[..., :-1, None] + t1 = tdist[..., 1:, None] + radii = radii[..., None] + + t_m = (t0 + t1) / 2 + t_d = (t1 - t0) / 2 + + j = torch.arange(6, device=tdist.device) + t = t0 + t_d / (t_d ** 2 + 3 * t_m ** 2) * (t1 ** 2 + 2 * t_m ** 2 + 3 / 7 ** 0.5 * (2 * j / 5 - 1) * ( + (t_d ** 2 - t_m ** 2) ** 2 + 4 * t_m ** 4).sqrt()) + + deg = torch.pi / 3 * torch.tensor([0, 2, 4, 3, 5, 1], device=tdist.device, dtype=torch.float) + deg = torch.broadcast_to(deg, t.shape) + if rand: + # randomly rotate and flip + mask = torch.rand_like(t0[..., 0]) > 0.5 + deg = deg + 2 * torch.pi * torch.rand_like(deg[..., 0])[..., None] + deg = torch.where(mask[..., None], deg, torch.pi * 5 / 3 - deg) + else: + # rotate 30 degree and flip every other pattern + mask = torch.arange(t.shape[-2], device=tdist.device) % 2 == 0 + mask = torch.broadcast_to(mask, t.shape[:-1]) + deg = torch.where(mask[..., None], deg, deg + torch.pi / 6) + deg = torch.where(mask[..., None], deg, torch.pi * 5 / 3 - deg) + means = torch.stack([ + radii * t * torch.cos(deg) / 2 ** 0.5, + radii * t * torch.sin(deg) / 2 ** 0.5, + t + ], dim=-1) + stds = std_scale * radii * t / 2 ** 0.5 + + # two basis in parallel to the image plane + rand_vec = torch.randn_like(cam_dirs) + ortho1 = F.normalize(torch.cross(cam_dirs, rand_vec, dim=-1), dim=-1) + ortho2 = F.normalize(torch.cross(cam_dirs, ortho1, dim=-1), dim=-1) + + # just use directions to be the third vector of the orthonormal basis, + # while the cross section of cone is parallel to the image plane + basis_matrix = torch.stack([ortho1, ortho2, directions], dim=-1) + means = math.matmul(means, basis_matrix[..., None, :, :].transpose(-1, -2)) + means = means + origins[..., None, None, :] + # import trimesh + # trimesh.Trimesh(means.reshape(-1, 3).detach().cpu().numpy()).export("test.ply", "ply") + + return means, stds, t + + +def compute_alpha_weights(density, tdist, dirs, opaque_background=False): + """Helper function for computing alpha compositing weights.""" + t_delta = tdist[..., 1:] - tdist[..., :-1] + delta = t_delta * torch.norm(dirs[..., None, :], dim=-1) + density_delta = density * delta + + if opaque_background: + # Equivalent to making the final t-interval infinitely wide. + density_delta = torch.cat([ + density_delta[..., :-1], + torch.full_like(density_delta[..., -1:], torch.inf) + ], dim=-1) + + alpha = 1 - torch.exp(-density_delta) + trans = torch.exp(-torch.cat([ + torch.zeros_like(density_delta[..., :1]), + torch.cumsum(density_delta[..., :-1], dim=-1) + ], dim=-1)) + weights = alpha * trans + return weights, alpha, trans + + +def volumetric_rendering(rgbs, + weights, + tdist, + bg_rgbs, + t_far, + compute_extras, + extras=None): + """Volumetric Rendering Function. + + Args: + rgbs: color, [batch_size, num_samples, 3] + weights: weights, [batch_size, num_samples]. + tdist: [batch_size, num_samples]. + bg_rgbs: the color(s) to use for the background. + t_far: [batch_size, 1], the distance of the far plane. + compute_extras: bool, if True, compute extra quantities besides color. + extras: dict, a set of values along rays to render by alpha compositing. + + Returns: + rendering: a dict containing an rgb image of size [batch_size, 3], and other + visualizations if compute_extras=True. + """ + eps = torch.finfo(rgbs.dtype).eps + # eps = 1e-3 + rendering = {} + + acc = weights.sum(dim=-1) + bg_w = (1 - acc[..., None]).clamp_min(0.) # The weight of the background. 
+ rgb = (weights[..., None] * rgbs).sum(dim=-2) + bg_w * bg_rgbs + t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:]) + depth = ( + torch.clip( + torch.nan_to_num((weights * t_mids).sum(dim=-1) / acc.clamp_min(eps), torch.inf), + tdist[..., 0], tdist[..., -1])) + + rendering['rgb'] = rgb + rendering['depth'] = depth + rendering['acc'] = acc + + if compute_extras: + if extras is not None: + for k, v in extras.items(): + if v is not None: + rendering[k] = (weights[..., None] * v).sum(dim=-2) + + expectation = lambda x: (weights * x).sum(dim=-1) / acc.clamp_min(eps) + # For numerical stability this expectation is computing using log-distance. + rendering['distance_mean'] = ( + torch.clip( + torch.nan_to_num(torch.exp(expectation(torch.log(t_mids))), torch.inf), + tdist[..., 0], tdist[..., -1])) + + # Add an extra fencepost with the far distance at the end of each ray, with + # whatever weight is needed to make the new weight vector sum to exactly 1 + # (`weights` is only guaranteed to sum to <= 1, not == 1). + t_aug = torch.cat([tdist, t_far], dim=-1) + weights_aug = torch.cat([weights, bg_w], dim=-1) + + ps = [5, 50, 95] + distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps) + + for i, p in enumerate(ps): + s = 'median' if p == 50 else 'percentile_' + str(p) + rendering['distance_' + s] = distance_percentiles[..., i] + + return rendering diff --git a/internal/stepfun.py b/internal/stepfun.py new file mode 100644 index 0000000000000000000000000000000000000000..2cc88945600d6636c5bcab9a86f2f33c2ae363b9 --- /dev/null +++ b/internal/stepfun.py @@ -0,0 +1,403 @@ +from internal import math +import numpy as np +import torch + + +def searchsorted(a, v): + """Find indices where v should be inserted into a to maintain order. + + Args: + a: tensor, the sorted reference points that we are scanning to see where v + should lie. + v: tensor, the query points that we are pretending to insert into a. Does + not need to be sorted. All but the last dimensions should match or expand + to those of a, the last dimension can differ. + + Returns: + (idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi], unless v is out of the + range [a[0], a[-1]] in which case idx_lo and idx_hi are both the first or + last index of a. 
+ """ + i = torch.arange(a.shape[-1], device=a.device) + v_ge_a = v[..., None, :] >= a[..., :, None] + idx_lo = torch.max(torch.where(v_ge_a, i[..., :, None], i[..., :1, None]), -2).values + idx_hi = torch.min(torch.where(~v_ge_a, i[..., :, None], i[..., -1:, None]), -2).values + return idx_lo, idx_hi + + +def query(tq, t, y, outside_value=0): + """Look up the values of the step function (t, y) at locations tq.""" + idx_lo, idx_hi = searchsorted(t, tq) + yq = torch.where(idx_lo == idx_hi, torch.full_like(idx_hi, outside_value), + torch.take_along_dim(y, idx_lo, dim=-1)) + return yq + + +def inner_outer(t0, t1, y1): + """Construct inner and outer measures on (t1, y1) for t0.""" + cy1 = torch.cat([torch.zeros_like(y1[..., :1]), + torch.cumsum(y1, dim=-1)], + dim=-1) + idx_lo, idx_hi = searchsorted(t1, t0) + + cy1_lo = torch.take_along_dim(cy1, idx_lo, dim=-1) + cy1_hi = torch.take_along_dim(cy1, idx_hi, dim=-1) + + y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1] + y0_inner = torch.where(idx_hi[..., :-1] <= idx_lo[..., 1:], + cy1_lo[..., 1:] - cy1_hi[..., :-1], torch.zeros_like(idx_lo[..., 1:])) + return y0_inner, y0_outer + + +def lossfun_outer(t, w, t_env, w_env): + """The proposal weight should be an upper envelope on the nerf weight.""" + eps = torch.finfo(t.dtype).eps + # eps = 1e-3 + + _, w_outer = inner_outer(t, t_env, w_env) + # We assume w_inner <= w <= w_outer. We don't penalize w_inner because it's + # more effective to pull w_outer up than it is to push w_inner down. + # Scaled half-quadratic loss that gives a constant gradient at w_outer = 0. + return (w - w_outer).clamp_min(0) ** 2 / (w + eps) + + +def weight_to_pdf(t, w): + """Turn a vector of weights that sums to 1 into a PDF that integrates to 1.""" + eps = torch.finfo(t.dtype).eps + return w / (t[..., 1:] - t[..., :-1]).clamp_min(eps) + + +def pdf_to_weight(t, p): + """Turn a PDF that integrates to 1 into a vector of weights that sums to 1.""" + return p * (t[..., 1:] - t[..., :-1]) + + +def max_dilate(t, w, dilation, domain=(-torch.inf, torch.inf)): + """Dilate (via max-pooling) a non-negative step function.""" + t0 = t[..., :-1] - dilation + t1 = t[..., 1:] + dilation + t_dilate, _ = torch.sort(torch.cat([t, t0, t1], dim=-1), dim=-1) + t_dilate = torch.clip(t_dilate, *domain) + w_dilate = torch.max( + torch.where( + (t0[..., None, :] <= t_dilate[..., None]) + & (t1[..., None, :] > t_dilate[..., None]), + w[..., None, :], + torch.zeros_like(w[..., None, :]), + ), dim=-1).values[..., :-1] + return t_dilate, w_dilate + + +def max_dilate_weights(t, + w, + dilation, + domain=(-torch.inf, torch.inf), + renormalize=False): + """Dilate (via max-pooling) a set of weights.""" + eps = torch.finfo(w.dtype).eps + # eps = 1e-3 + + p = weight_to_pdf(t, w) + t_dilate, p_dilate = max_dilate(t, p, dilation, domain=domain) + w_dilate = pdf_to_weight(t_dilate, p_dilate) + if renormalize: + w_dilate /= torch.sum(w_dilate, dim=-1, keepdim=True).clamp_min(eps) + return t_dilate, w_dilate + + +def integrate_weights(w): + """Compute the cumulative sum of w, assuming all weight vectors sum to 1. + + The output's size on the last dimension is one greater than that of the input, + because we're computing the integral corresponding to the endpoints of a step + function, not the integral of the interior/bin values. + + Args: + w: Tensor, which will be integrated along the last axis. This is assumed to + sum to 1 along the last axis, and this function will (silently) break if + that is not the case. 
+ + Returns: + cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1 + """ + cw = torch.cumsum(w[..., :-1], dim=-1).clamp_max(1) + shape = cw.shape[:-1] + (1,) + # Ensure that the CDF starts with exactly 0 and ends with exactly 1. + cw0 = torch.cat([torch.zeros(shape, device=cw.device), cw, + torch.ones(shape, device=cw.device)], dim=-1) + return cw0 + + +def integrate_weights_np(w): + """Compute the cumulative sum of w, assuming all weight vectors sum to 1. + + The output's size on the last dimension is one greater than that of the input, + because we're computing the integral corresponding to the endpoints of a step + function, not the integral of the interior/bin values. + + Args: + w: Tensor, which will be integrated along the last axis. This is assumed to + sum to 1 along the last axis, and this function will (silently) break if + that is not the case. + + Returns: + cw0: Tensor, the integral of w, where cw0[..., 0] = 0 and cw0[..., -1] = 1 + """ + cw = np.minimum(1, np.cumsum(w[..., :-1], axis=-1)) + shape = cw.shape[:-1] + (1,) + # Ensure that the CDF starts with exactly 0 and ends with exactly 1. + cw0 = np.concatenate([np.zeros(shape), cw, + np.ones(shape)], axis=-1) + return cw0 + + +def invert_cdf(u, t, w_logits): + """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).""" + # Compute the PDF and CDF for each weight vector. + w = torch.softmax(w_logits, dim=-1) + cw = integrate_weights(w) + # Interpolate into the inverse CDF. + t_new = math.sorted_interp(u, cw, t) + return t_new + + +def invert_cdf_np(u, t, w_logits): + """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).""" + # Compute the PDF and CDF for each weight vector. + w = np.exp(w_logits) / np.exp(w_logits).sum(axis=-1, keepdims=True) + cw = integrate_weights_np(w) + # Interpolate into the inverse CDF. + interp_fn = np.interp + t_new = interp_fn(u, cw, t) + return t_new + + +def sample(rand, + t, + w_logits, + num_samples, + single_jitter=False, + deterministic_center=False): + """Piecewise-Constant PDF sampling from a step function. + + Args: + rand: random number generator (or None for `linspace` sampling). + t: [..., num_bins + 1], bin endpoint coordinates (must be sorted) + w_logits: [..., num_bins], logits corresponding to bin weights + num_samples: int, the number of samples. + single_jitter: bool, if True, jitter every sample along each ray by the same + amount in the inverse CDF. Otherwise, jitter each sample independently. + deterministic_center: bool, if False, when `rand` is None return samples that + linspace the entire PDF. If True, skip the front and back of the linspace + so that the centers of each PDF interval are returned. + + Returns: + t_samples: [batch_size, num_samples]. + """ + eps = torch.finfo(t.dtype).eps + # eps = 1e-3 + + device = t.device + + # Draw uniform samples. + if not rand: + if deterministic_center: + pad = 1 / (2 * num_samples) + u = torch.linspace(pad, 1. - pad - eps, num_samples, device=device) + else: + u = torch.linspace(0, 1. - eps, num_samples, device=device) + u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,)) + else: + # `u` is in [0, 1) --- it can be zero, but it can never be 1. 
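An illustrative run of the numpy inverse-CDF path above (`invert_cdf_np` with deterministic, evenly spaced `u`): when most of the probability mass sits on the middle bin, most samples land inside that bin.

import numpy as np

t = np.array([0.0, 1.0, 2.0, 3.0])            # 3 bins
w_logits = np.array([0.0, 2.0, 0.0])          # softmax ≈ [0.107, 0.787, 0.107]
w = np.exp(w_logits) / np.exp(w_logits).sum()
cw = np.concatenate([[0.], np.minimum(1., np.cumsum(w[:-1])), [1.]])
u = np.linspace(0., 1. - 1e-6, 5)
print(np.interp(u, cw, t))                    # ≈ [0.0, 1.18, 1.5, 1.82, 3.0]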
+ u_max = eps + (1 - eps) / num_samples + max_jitter = (1 - u_max) / (num_samples - 1) - eps + d = 1 if single_jitter else num_samples + u = torch.linspace(0, 1 - u_max, num_samples, device=device) + \ + torch.rand(t.shape[:-1] + (d,), device=device) * max_jitter + + return invert_cdf(u, t, w_logits) + + +def sample_np(rand, + t, + w_logits, + num_samples, + single_jitter=False, + deterministic_center=False): + """ + numpy version of sample() + """ + eps = np.finfo(np.float32).eps + + # Draw uniform samples. + if not rand: + if deterministic_center: + pad = 1 / (2 * num_samples) + u = np.linspace(pad, 1. - pad - eps, num_samples) + else: + u = np.linspace(0, 1. - eps, num_samples) + u = np.broadcast_to(u, t.shape[:-1] + (num_samples,)) + else: + # `u` is in [0, 1) --- it can be zero, but it can never be 1. + u_max = eps + (1 - eps) / num_samples + max_jitter = (1 - u_max) / (num_samples - 1) - eps + d = 1 if single_jitter else num_samples + u = np.linspace(0, 1 - u_max, num_samples) + \ + np.random.rand(*t.shape[:-1], d) * max_jitter + + return invert_cdf_np(u, t, w_logits) + + +def sample_intervals(rand, + t, + w_logits, + num_samples, + single_jitter=False, + domain=(-torch.inf, torch.inf)): + """Sample *intervals* (rather than points) from a step function. + + Args: + rand: random number generator (or None for `linspace` sampling). + t: [..., num_bins + 1], bin endpoint coordinates (must be sorted) + w_logits: [..., num_bins], logits corresponding to bin weights + num_samples: int, the number of intervals to sample. + single_jitter: bool, if True, jitter every sample along each ray by the same + amount in the inverse CDF. Otherwise, jitter each sample independently. + domain: (minval, maxval), the range of valid values for `t`. + + Returns: + t_samples: [batch_size, num_samples]. + """ + if num_samples <= 1: + raise ValueError(f'num_samples must be > 1, is {num_samples}.') + + # Sample a set of points from the step function. + centers = sample( + rand, + t, + w_logits, + num_samples, + single_jitter, + deterministic_center=True) + + # The intervals we return will span the midpoints of each adjacent sample. + mid = (centers[..., 1:] + centers[..., :-1]) / 2 + + # Each first/last fencepost is the reflection of the first/last midpoint + # around the first/last sampled center. We clamp to the limits of the input + # domain, provided by the caller. + minval, maxval = domain + first = (2 * centers[..., :1] - mid[..., :1]).clamp_min(minval) + last = (2 * centers[..., -1:] - mid[..., -1:]).clamp_max(maxval) + + t_samples = torch.cat([first, mid, last], dim=-1) + return t_samples + + +def lossfun_distortion(t, w): + """Compute iint w[i] w[j] |t[i] - t[j]| di dj.""" + # The loss incurred between all pairs of intervals. + ut = (t[..., 1:] + t[..., :-1]) / 2 + dut = torch.abs(ut[..., :, None] - ut[..., None, :]) + loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1) + + # The loss incurred within each individual interval with itself. + loss_intra = torch.sum(w ** 2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3 + + return loss_inter + loss_intra + + +def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi): + """Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi]).""" + # Distortion when the intervals do not overlap. + d_disjoint = torch.abs((t1_lo + t1_hi) / 2 - (t0_lo + t0_hi) / 2) + + # Distortion when the intervals overlap. 
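A toy illustration of `lossfun_distortion` (assuming this repo's `internal` package is importable): spreading the weight uniformly along the ray is penalized, while concentrating it in a single narrow bin is nearly free.

import torch
from internal import stepfun  # assumes this repo is on the Python path

t = torch.linspace(0., 1., 9)[None]              # 8 equal bins on one ray
w_uniform = torch.full((1, 8), 1. / 8)
w_peaked = torch.zeros(1, 8)
w_peaked[0, 3] = 1.
print(stepfun.lossfun_distortion(t, w_uniform))  # ≈ 0.333
print(stepfun.lossfun_distortion(t, w_peaked))   # ≈ 0.042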
+ d_overlap = (2 * + (torch.minimum(t0_hi, t1_hi) ** 3 - torch.maximum(t0_lo, t1_lo) ** 3) + + 3 * (t1_hi * t0_hi * torch.abs(t1_hi - t0_hi) + + t1_lo * t0_lo * torch.abs(t1_lo - t0_lo) + t1_hi * t0_lo * + (t0_lo - t1_hi) + t1_lo * t0_hi * + (t1_lo - t0_hi))) / (6 * (t0_hi - t0_lo) * (t1_hi - t1_lo)) + + # Are the two intervals not overlapping? + are_disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi) + + return torch.where(are_disjoint, d_disjoint, d_overlap) + + +def weighted_percentile(t, w, ps): + """Compute the weighted percentiles of a step function. w's must sum to 1.""" + cw = integrate_weights(w) + # We want to interpolate into the integrated weights according to `ps`. + fn = lambda cw_i, t_i: math.sorted_interp(torch.tensor(ps, device=t.device) / 100, cw_i, t_i) + # Vmap fn to an arbitrary number of leading dimensions. + cw_mat = cw.reshape([-1, cw.shape[-1]]) + t_mat = t.reshape([-1, t.shape[-1]]) + wprctile_mat = fn(cw_mat, t_mat) # TODO + wprctile = wprctile_mat.reshape(cw.shape[:-1] + (len(ps),)) + return wprctile + + +def resample(t, tp, vp, use_avg=False): + """Resample a step function defined by (tp, vp) into intervals t. + + Args: + t: tensor with shape (..., n+1), the endpoints to resample into. + tp: tensor with shape (..., m+1), the endpoints of the step function being + resampled. + vp: tensor with shape (..., m), the values of the step function being + resampled. + use_avg: bool, if False, return the sum of the step function for each + interval in `t`. If True, return the average, weighted by the width of + each interval in `t`. + eps: float, a small value to prevent division by zero when use_avg=True. + + Returns: + v: tensor with shape (..., n), the values of the resampled step function. + """ + eps = torch.finfo(t.dtype).eps + # eps = 1e-3 + + if use_avg: + wp = torch.diff(tp, dim=-1) + v_numer = resample(t, tp, vp * wp, use_avg=False) + v_denom = resample(t, tp, wp, use_avg=False) + v = v_numer / v_denom.clamp_min(eps) + return v + + acc = torch.cumsum(vp, dim=-1) + acc0 = torch.cat([torch.zeros(acc.shape[:-1] + (1,), device=acc.device), acc], dim=-1) + acc0_resampled = math.sorted_interp(t, tp, acc0) # TODO + v = torch.diff(acc0_resampled, dim=-1) + return v + + +def resample_np(t, tp, vp, use_avg=False): + """ + numpy version of resample + """ + eps = np.finfo(t.dtype).eps + if use_avg: + wp = np.diff(tp, axis=-1) + v_numer = resample_np(t, tp, vp * wp, use_avg=False) + v_denom = resample_np(t, tp, wp, use_avg=False) + v = v_numer / np.maximum(eps, v_denom) + return v + + acc = np.cumsum(vp, axis=-1) + acc0 = np.concatenate([np.zeros(acc.shape[:-1] + (1,)), acc], axis=-1) + acc0_resampled = np.vectorize(np.interp, signature='(n),(m),(m)->(n)')(t, tp, acc0) + v = np.diff(acc0_resampled, axis=-1) + return v + + +def blur_stepfun(x, y, r): + xr, xr_idx = torch.sort(torch.cat([x - r, x + r], dim=-1)) + y1 = (torch.cat([y, torch.zeros_like(y[..., :1])], dim=-1) - + torch.cat([torch.zeros_like(y[..., :1]), y], dim=-1)) / (2 * r) + y2 = torch.cat([y1, -y1], dim=-1).take_along_dim(xr_idx[..., :-1], dim=-1) + yr = torch.cumsum((xr[..., 1:] - xr[..., :-1]) * + torch.cumsum(y2, dim=-1), dim=-1).clamp_min(0) + yr = torch.cat([torch.zeros_like(yr[..., :1]), yr], dim=-1) + return xr, yr diff --git a/internal/train_utils.py b/internal/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a206009fc1325122405f4cf8933986b70b33ea3e --- /dev/null +++ b/internal/train_utils.py @@ -0,0 +1,263 @@ +import collections +import functools + +import torch.optim +from 
internal import camera_utils +from internal import configs +from internal import datasets +from internal import image +from internal import math +from internal import models +from internal import ref_utils +from internal import stepfun +from internal import utils +import numpy as np +from torch.utils._pytree import tree_map, tree_flatten +from torch_scatter import segment_coo + + +class GradientScaler(torch.autograd.Function): + @staticmethod + def forward(ctx, colors, sigmas, ray_dist): + ctx.save_for_backward(ray_dist) + return colors, sigmas + + @staticmethod + def backward(ctx, grad_output_colors, grad_output_sigmas): + (ray_dist,) = ctx.saved_tensors + scaling = torch.square(ray_dist).clamp(0, 1) + return grad_output_colors * scaling[..., None], grad_output_sigmas * scaling, None + + +def tree_reduce(fn, tree, initializer=0): + return functools.reduce(fn, tree_flatten(tree)[0], initializer) + + +def tree_sum(tree): + return tree_reduce(lambda x, y: x + y, tree, initializer=0) + + +def tree_norm_sq(tree): + return tree_sum(tree_map(lambda x: torch.sum(x ** 2), tree)) + + +def tree_norm(tree): + return torch.sqrt(tree_norm_sq(tree)) + + +def tree_abs_max(tree): + return tree_reduce( + lambda x, y: max(x, torch.abs(y).max().item()), tree, initializer=0) + + +def tree_len(tree): + return tree_sum(tree_map(lambda z: np.prod(z.shape), tree)) + + +def summarize_tree(tree, fn, ancestry=(), max_depth=3): + """Flatten 'tree' while 'fn'-ing values and formatting keys like/this.""" + stats = {} + for k, v in tree.items(): + name = ancestry + (k,) + stats['/'.join(name)] = fn(v) + if hasattr(v, 'items') and len(ancestry) < (max_depth - 1): + stats.update(summarize_tree(v, fn, ancestry=name, max_depth=max_depth)) + return stats + + +def compute_data_loss(batch, renderings, config): + """Computes data loss terms for RGB, normal, and depth outputs.""" + data_losses = [] + stats = collections.defaultdict(lambda: []) + + # lossmult can be used to apply a weight to each ray in the batch. + # For example: masking out rays, applying the Bayer mosaic mask, upweighting + # rays from lower resolution images and so on. + lossmult = batch['lossmult'] + lossmult = torch.broadcast_to(lossmult, batch['rgb'][..., :3].shape) + if config.disable_multiscale_loss: + lossmult = torch.ones_like(lossmult) + + for rendering in renderings: + resid_sq = (rendering['rgb'] - batch['rgb'][..., :3]) ** 2 + denom = lossmult.sum() + stats['mses'].append(((lossmult * resid_sq).sum() / denom).item()) + + if config.data_loss_type == 'mse': + # Mean-squared error (L2) loss. + data_loss = resid_sq + elif config.data_loss_type == 'charb': + # Charbonnier loss. + data_loss = torch.sqrt(resid_sq + config.charb_padding ** 2) + elif config.data_loss_type == 'rawnerf': + # Clip raw values against 1 to match sensor overexposure behavior. + rgb_render_clip = rendering['rgb'].clamp_max(1) + resid_sq_clip = (rgb_render_clip - batch['rgb'][..., :3]) ** 2 + # Scale by gradient of log tonemapping curve. + scaling_grad = 1. / (1e-3 + rgb_render_clip.detach()) + # Reweighted L2 loss. + data_loss = resid_sq_clip * scaling_grad ** 2 + else: + assert False + data_losses.append((lossmult * data_loss).sum() / denom) + + if config.compute_disp_metrics: + # Using mean to compute disparity, but other distance statistics can + # be used instead. 
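For context, a minimal sketch of how the `GradientScaler` defined at the top of this file is meant to be used (illustrative shapes, not the actual training loop): the forward pass is an identity on colors and densities, but the gradients flowing back into them are scaled by clamp(ray_dist**2, 0, 1), damping updates from samples close to the camera.

import torch
from internal import train_utils  # assumes this repo is on the Python path

colors = torch.randn(4, 3, requires_grad=True)
sigmas = torch.randn(4, requires_grad=True)
ray_dist = torch.tensor([0.1, 0.5, 1.0, 2.0])
c_out, s_out = train_utils.GradientScaler.apply(colors, sigmas, ray_dist)
(c_out.sum() + s_out.sum()).backward()
print(sigmas.grad)   # [0.01, 0.25, 1.0, 1.0]: unit gradients scaled by clamped dist^2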
+ disp = 1 / (1 + rendering['distance_mean']) + stats['disparity_mses'].append(((disp - batch['disps']) ** 2).mean().item()) + + if config.compute_normal_metrics: + if 'normals' in rendering: + weights = rendering['acc'] * batch['alphas'] + normalized_normals_gt = ref_utils.l2_normalize(batch['normals']) + normalized_normals = ref_utils.l2_normalize(rendering['normals']) + normal_mae = ref_utils.compute_weighted_mae(weights, normalized_normals, + normalized_normals_gt) + else: + # If normals are not computed, set MAE to NaN. + normal_mae = torch.nan + stats['normal_maes'].append(normal_mae.item()) + + loss = ( + config.data_coarse_loss_mult * sum(data_losses[:-1]) + + config.data_loss_mult * data_losses[-1]) + + stats = {k: np.array(stats[k]) for k in stats} + return loss, stats + + +def interlevel_loss(ray_history, config): + """Computes the interlevel loss defined in mip-NeRF 360.""" + # Stop the gradient from the interlevel loss onto the NeRF MLP. + last_ray_results = ray_history[-1] + c = last_ray_results['sdist'].detach() + w = last_ray_results['weights'].detach() + loss_interlevel = 0. + for ray_results in ray_history[:-1]: + cp = ray_results['sdist'] + wp = ray_results['weights'] + loss_interlevel += stepfun.lossfun_outer(c, w, cp, wp).mean() + return config.interlevel_loss_mult * loss_interlevel + + +def anti_interlevel_loss(ray_history, config): + """Computes the interlevel loss defined in mip-NeRF 360.""" + last_ray_results = ray_history[-1] + c = last_ray_results['sdist'].detach() + w = last_ray_results['weights'].detach() + w_normalize = w / (c[..., 1:] - c[..., :-1]) + loss_anti_interlevel = 0. + for i, ray_results in enumerate(ray_history[:-1]): + cp = ray_results['sdist'] + wp = ray_results['weights'] + c_, w_ = stepfun.blur_stepfun(c, w_normalize, config.pulse_width[i]) + + # piecewise linear pdf to piecewise quadratic cdf + area = 0.5 * (w_[..., 1:] + w_[..., :-1]) * (c_[..., 1:] - c_[..., :-1]) + + cdf = torch.cat([torch.zeros_like(area[..., :1]), torch.cumsum(area, dim=-1)], dim=-1) + + # query piecewise quadratic interpolation + cdf_interp = math.sorted_interp_quad(cp, c_, w_, cdf) + # difference between adjacent interpolated values + w_s = torch.diff(cdf_interp, dim=-1) + + loss_anti_interlevel += ((w_s - wp).clamp_min(0) ** 2 / (wp + 1e-5)).mean() + return config.anti_interlevel_loss_mult * loss_anti_interlevel + + +def distortion_loss(ray_history, config): + """Computes the distortion loss regularizer defined in mip-NeRF 360.""" + last_ray_results = ray_history[-1] + c = last_ray_results['sdist'] + w = last_ray_results['weights'] + loss = stepfun.lossfun_distortion(c, w).mean() + return config.distortion_loss_mult * loss + + +def orientation_loss(batch, model, ray_history, config): + """Computes the orientation loss regularizer defined in ref-NeRF.""" + total_loss = 0. + for i, ray_results in enumerate(ray_history): + w = ray_results['weights'] + n = ray_results[config.orientation_loss_target] + if n is None: + raise ValueError('Normals cannot be None if orientation loss is on.') + # Negate viewdirs to represent normalized vectors from point to camera. + v = -1. * batch['viewdirs'] + n_dot_v = (n * v[..., None, :]).sum(dim=-1) + loss = (w * n_dot_v.clamp_min(0) ** 2).sum(dim=-1).mean() + if i < model.num_levels - 1: + total_loss += config.orientation_coarse_loss_mult * loss + else: + total_loss += config.orientation_loss_mult * loss + return total_loss + + +def hash_decay_loss(ray_history, config): + total_loss = 0. 
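A small check of the proposal supervision used by `interlevel_loss` (assuming this repo's `internal` package is importable): a proposal histogram whose outer measure covers the NeRF weights incurs zero loss, while moving mass away from the NeRF's peak makes the loss positive.

import torch
from internal import stepfun  # assumes this repo is on the Python path

t = torch.linspace(0., 1., 5)[None]                  # shared bin edges
w = torch.tensor([[0.1, 0.6, 0.2, 0.1]])             # NeRF weights
w_good = w.clone()                                   # proposal covering the NeRF weights
w_bad = torch.tensor([[0.6, 0.1, 0.2, 0.1]])         # mass moved off the NeRF's peak
print(stepfun.lossfun_outer(t, w, t, w_good).sum())  # 0.0
print(stepfun.lossfun_outer(t, w, t, w_bad).sum())   # > 0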
+ for i, ray_results in enumerate(ray_history): + total_loss += config.hash_decay_mults * ray_results['loss_hash_decay'] + return total_loss + + +def opacity_loss(renderings, config): + total_loss = 0. + for i, rendering in enumerate(renderings): + o = rendering['acc'] + total_loss += config.opacity_loss_mult * (-o * torch.log(o + 1e-5)).mean() + return total_loss + + +def predicted_normal_loss(model, ray_history, config): + """Computes the predicted normal supervision loss defined in ref-NeRF.""" + total_loss = 0. + for i, ray_results in enumerate(ray_history): + w = ray_results['weights'] + n = ray_results['normals'] + n_pred = ray_results['normals_pred'] + if n is None or n_pred is None: + raise ValueError( + 'Predicted normals and gradient normals cannot be None if ' + 'predicted normal loss is on.') + loss = torch.mean((w * (1.0 - torch.sum(n * n_pred, dim=-1))).sum(dim=-1)) + if i < model.num_levels - 1: + total_loss += config.predicted_normal_coarse_loss_mult * loss + else: + total_loss += config.predicted_normal_loss_mult * loss + return total_loss + + +def clip_gradients(model, accelerator, config): + """Clips gradients of MLP based on norm and max value.""" + if config.grad_max_norm > 0 and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), config.grad_max_norm) + + if config.grad_max_val > 0 and accelerator.sync_gradients: + accelerator.clip_grad_value_(model.parameters(), config.grad_max_val) + + for param in model.parameters(): + param.grad.nan_to_num_() + + +def create_optimizer(config: configs.Config, model): + """Creates optax optimizer for model training.""" + adam_kwargs = { + 'betas': [config.adam_beta1, config.adam_beta2], + 'eps': config.adam_eps, + } + lr_kwargs = { + 'max_steps': config.max_steps, + 'lr_delay_steps': config.lr_delay_steps, + 'lr_delay_mult': config.lr_delay_mult, + } + + lr_fn_main = lambda step: math.learning_rate_decay( + step, + lr_init=config.lr_init, + lr_final=config.lr_final, + **lr_kwargs) + optimizer = torch.optim.Adam(model.parameters(), lr=config.lr_init, **adam_kwargs) + + return optimizer, lr_fn_main diff --git a/internal/utils.py b/internal/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aef3f19d84a2ebc0a7ff48cd6a40f741cb25085c --- /dev/null +++ b/internal/utils.py @@ -0,0 +1,119 @@ +import enum +import logging +import os + +import cv2 +import torch +import numpy as np +from PIL import ExifTags +from PIL import Image +import collections +import random +from internal import vis +from matplotlib import cm + + +class Timing: + """ + Timing environment + usage: + with Timing("message"): + your commands here + will print CUDA runtime in ms + """ + + def __init__(self, name): + self.name = name + + def __enter__(self): + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + self.start.record() + + def __exit__(self, type, value, traceback): + self.end.record() + torch.cuda.synchronize() + print(self.name, "elapsed", self.start.elapsed_time(self.end), "ms") + + +def handle_exception(exc_type, exc_value, exc_traceback): + logging.error("Error!", exc_info=(exc_type, exc_value, exc_traceback)) + + +def nan_sum(x): + return (torch.isnan(x) | torch.isinf(x)).sum() + + +def flatten_dict(d, parent_key='', sep='_'): + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, collections.abc.MutableMapping): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, 
v)) + return dict(items) + + +class DataSplit(enum.Enum): + """Dataset split.""" + TRAIN = 'train' + TEST = 'test' + + +class BatchingMethod(enum.Enum): + """Draw rays randomly from a single image or all images, in each batch.""" + ALL_IMAGES = 'all_images' + SINGLE_IMAGE = 'single_image' + + +def open_file(pth, mode='r'): + return open(pth, mode=mode) + + +def file_exists(pth): + return os.path.exists(pth) + + +def listdir(pth): + return os.listdir(pth) + + +def isdir(pth): + return os.path.isdir(pth) + + +def makedirs(pth): + os.makedirs(pth, exist_ok=True) + + +def load_img(pth): + """Load an image and cast to float32.""" + image = np.array(Image.open(pth), dtype=np.float32) + return image + + +def load_exif(pth): + """Load EXIF data for an image.""" + with open_file(pth, 'rb') as f: + image_pil = Image.open(f) + exif_pil = image_pil._getexif() # pylint: disable=protected-access + if exif_pil is not None: + exif = { + ExifTags.TAGS[k]: v for k, v in exif_pil.items() if k in ExifTags.TAGS + } + else: + exif = {} + return exif + + +def save_img_u8(img, pth): + """Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG.""" + Image.fromarray( + (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)).save( + pth, 'PNG') + + +def save_img_f32(depthmap, pth, p=0.5): + """Save an image (probably a depthmap) to disk as a float32 TIFF.""" + Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(pth, 'TIFF') diff --git a/internal/vis.py b/internal/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..55e437a1338148be10cb80407b83567d3df43bd8 --- /dev/null +++ b/internal/vis.py @@ -0,0 +1,246 @@ +from internal import stepfun +import numpy as np +from matplotlib import cm + + +def weighted_percentile(x, w, ps, assume_sorted=False): + """Compute the weighted percentile(s) of a single vector.""" + if len(x.shape) != len(w.shape): + w = np.broadcast_to(w[..., None], x.shape) + x = x.reshape([-1]) + w = w.reshape([-1]) + if not assume_sorted: + sortidx = np.argsort(x) + x, w = x[sortidx], w[sortidx] + acc_w = np.cumsum(w) + return np.interp(np.array(ps) * (acc_w[-1] / 100), acc_w, x) + + +def sinebow(h): + """A cyclic and uniform colormap, see http://basecase.org/env/on-rainbows.""" + f = lambda x: np.sin(np.pi * x) ** 2 + return np.stack([f(3 / 6 - h), f(5 / 6 - h), f(7 / 6 - h)], -1) + + +def matte(vis, acc, dark=0.8, light=1.0, width=8): + """Set non-accumulated pixels to a Photoshop-esque checker pattern.""" + bg_mask = np.logical_xor( + (np.arange(acc.shape[0]) % (2 * width) // width)[:, None], + (np.arange(acc.shape[1]) % (2 * width) // width)[None, :]) + bg = np.where(bg_mask, light, dark) + return vis * acc[:, :, None] + (bg * (1 - acc))[:, :, None] + + +def visualize_cmap(value, + weight, + colormap, + lo=None, + hi=None, + percentile=99., + curve_fn=lambda x: x, + modulus=None, + matte_background=True): + """Visualize a 1D image and a 1D weighting according to some colormap. + + Args: + value: A 1D image. + weight: A weight map, in [0, 1]. + colormap: A colormap function. + lo: The lower bound to use when rendering, if None then use a percentile. + hi: The upper bound to use when rendering, if None then use a percentile. + percentile: What percentile of the value map to crop to when automatically + generating `lo` and `hi`. Depends on `weight` as well as `value'. + curve_fn: A curve function that gets applied to `value`, `lo`, and `hi` + before the rest of visualization. Good choices: x, 1/(x+eps), log(x+eps). 
+ modulus: If not None, mod the normalized value by `modulus`. Use (0, 1]. If + `modulus` is not None, `lo`, `hi` and `percentile` will have no effect. + matte_background: If True, matte the image over a checkerboard. + + Returns: + A colormap rendering. + """ + # Identify the values that bound the middle of `value' according to `weight`. + lo_auto, hi_auto = weighted_percentile( + value, weight, [50 - percentile / 2, 50 + percentile / 2], assume_sorted=True) + + # If `lo` or `hi` are None, use the automatically-computed bounds above. + eps = np.finfo(np.float32).eps + lo = lo or (lo_auto - eps) + hi = hi or (hi_auto + eps) + + # Curve all values. + value, lo, hi = [curve_fn(x) for x in [value, lo, hi]] + + # Wrap the values around if requested. + if modulus: + value = np.mod(value, modulus) / modulus + else: + # Otherwise, just scale to [0, 1]. + value = np.nan_to_num( + np.clip((value - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)) + + if colormap: + colorized = colormap(value)[:, :, :3] + else: + if len(value.shape) != 3: + raise ValueError(f'value must have 3 dims but has {len(value.shape)}') + if value.shape[-1] != 3: + raise ValueError( + f'value must have 3 channels but has {len(value.shape[-1])}') + colorized = value + + return matte(colorized, weight) if matte_background else colorized + + +def visualize_coord_mod(coords, acc): + """Visualize the coordinate of each point within its "cell".""" + return matte(((coords + 1) % 2) / 2, acc) + + +def visualize_rays(dist, + dist_range, + weights, + rgbs, + accumulate=False, + renormalize=False, + resolution=2048, + bg_color=0.8): + """Visualize a bundle of rays.""" + dist_vis = np.linspace(*dist_range, resolution + 1) + vis_rgb, vis_alpha = [], [] + for ds, ws, rs in zip(dist, weights, rgbs): + vis_rs, vis_ws = [], [] + for d, w, r in zip(ds, ws, rs): + if accumulate: + # Produce the accumulated color and weight at each point along the ray. + w_csum = np.cumsum(w, axis=0) + rw_csum = np.cumsum((r * w[:, None]), axis=0) + eps = np.finfo(np.float32).eps + r, w = (rw_csum + eps) / (w_csum[:, None] + 2 * eps), w_csum + vis_rs.append(stepfun.resample_np(dist_vis, d, r.T, use_avg=True).T) + vis_ws.append(stepfun.resample_np(dist_vis, d, w.T, use_avg=True).T) + vis_rgb.append(np.stack(vis_rs)) + vis_alpha.append(np.stack(vis_ws)) + vis_rgb = np.stack(vis_rgb, axis=1) + vis_alpha = np.stack(vis_alpha, axis=1) + + if renormalize: + # Scale the alphas so that the largest value is 1, for visualization. + vis_alpha /= np.maximum(np.finfo(np.float32).eps, np.max(vis_alpha)) + + if resolution > vis_rgb.shape[0]: + rep = resolution // (vis_rgb.shape[0] * vis_rgb.shape[1] + 1) + stride = rep * vis_rgb.shape[1] + + vis_rgb = np.tile(vis_rgb, (1, 1, rep, 1)).reshape((-1,) + vis_rgb.shape[2:]) + vis_alpha = np.tile(vis_alpha, (1, 1, rep)).reshape((-1,) + vis_alpha.shape[2:]) + + # Add a strip of background pixels after each set of levels of rays. + vis_rgb = vis_rgb.reshape((-1, stride) + vis_rgb.shape[1:]) + vis_alpha = vis_alpha.reshape((-1, stride) + vis_alpha.shape[1:]) + vis_rgb = np.concatenate([vis_rgb, np.zeros_like(vis_rgb[:, :1])], + axis=1).reshape((-1,) + vis_rgb.shape[2:]) + vis_alpha = np.concatenate( + [vis_alpha, np.zeros_like(vis_alpha[:, :1])], + axis=1).reshape((-1,) + vis_alpha.shape[2:]) + + # Matte the RGB image over the background. + vis = vis_rgb * vis_alpha[..., None] + (bg_color * (1 - vis_alpha))[..., None] + + # Remove the final row of background pixels. 
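An illustrative call to `visualize_cmap` (assuming this repo's `internal` package is importable): map a depth image and its accumulated opacity to a turbo-colormapped rendering, using the same negative-log distance curve applied elsewhere in the codebase.

import numpy as np
from matplotlib import cm
from internal import vis  # assumes this repo is on the Python path

depth = np.random.uniform(1., 10., size=(120, 160))
acc = np.ones_like(depth)
depth_vis = vis.visualize_cmap(
    depth, acc, cm.get_cmap('turbo'),
    curve_fn=lambda x: -np.log(x + np.finfo(np.float32).eps))
print(depth_vis.shape)  # (120, 160, 3)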
+ vis = vis[:-1] + vis_alpha = vis_alpha[:-1] + return vis, vis_alpha + + +def visualize_suite(rendering, batch): + """A wrapper around other visualizations for easy integration.""" + + depth_curve_fn = lambda x: -np.log(x + np.finfo(np.float32).eps) + + rgb = rendering['rgb'] + acc = rendering['acc'] + + distance_mean = rendering['distance_mean'] + distance_median = rendering['distance_median'] + distance_p5 = rendering['distance_percentile_5'] + distance_p95 = rendering['distance_percentile_95'] + acc = np.where(np.isnan(distance_mean), np.zeros_like(acc), acc) + + # The xyz coordinates where rays terminate. + coords = batch['origins'] + batch['directions'] * distance_mean[:, :, None] + + vis_depth_mean, vis_depth_median = [ + visualize_cmap(x, acc, cm.get_cmap('turbo'), curve_fn=depth_curve_fn) + for x in [distance_mean, distance_median] + ] + + # Render three depth percentiles directly to RGB channels, where the spacing + # determines the color. delta == big change, epsilon = small change. + # Gray: A strong discontinuitiy, [x-epsilon, x, x+epsilon] + # Purple: A thin but even density, [x-delta, x, x+delta] + # Red: A thin density, then a thick density, [x-delta, x, x+epsilon] + # Blue: A thick density, then a thin density, [x-epsilon, x, x+delta] + vis_depth_triplet = visualize_cmap( + np.stack( + [2 * distance_median - distance_p5, distance_median, distance_p95], + axis=-1), + acc, + None, + curve_fn=lambda x: np.log(x + np.finfo(np.float32).eps)) + + dist = rendering['ray_sdist'] + dist_range = (0, 1) + weights = rendering['ray_weights'] + rgbs = [np.clip(r, 0, 1) for r in rendering['ray_rgbs']] + + vis_ray_colors, _ = visualize_rays(dist, dist_range, weights, rgbs) + + sqrt_weights = [np.sqrt(w) for w in weights] + sqrt_ray_weights, ray_alpha = visualize_rays( + dist, + dist_range, + [np.ones_like(lw) for lw in sqrt_weights], + [lw[..., None] for lw in sqrt_weights], + bg_color=0, + ) + sqrt_ray_weights = sqrt_ray_weights[..., 0] + + null_color = np.array([1., 0., 0.]) + vis_ray_weights = np.where( + ray_alpha[:, :, None] == 0, + null_color[None, None], + visualize_cmap( + sqrt_ray_weights, + np.ones_like(sqrt_ray_weights), + cm.get_cmap('gray'), + lo=0, + hi=1, + matte_background=False, + ), + ) + + vis = { + 'color': rgb, + 'acc': acc, + 'color_matte': matte(rgb, acc), + 'depth_mean': vis_depth_mean, + 'depth_median': vis_depth_median, + 'depth_triplet': vis_depth_triplet, + 'coords_mod': visualize_coord_mod(coords, acc), + 'ray_colors': vis_ray_colors, + 'ray_weights': vis_ray_weights, + } + + if 'rgb_cc' in rendering: + vis['color_corrected'] = rendering['rgb_cc'] + + # Render every item named "normals*". + for key, val in rendering.items(): + if key.startswith('normals'): + vis[key] = matte(val / 2. 
+ 0.5, acc) + + if 'roughness' in rendering: + vis['roughness'] = matte(np.tanh(rendering['roughness']), acc) + + return vis diff --git a/render.py b/render.py new file mode 100644 index 0000000000000000000000000000000000000000..7f90383025244cb0fec8c6b12671da2eadbf2b20 --- /dev/null +++ b/render.py @@ -0,0 +1,172 @@ +import glob +import logging +import os +import sys +import time + +from absl import app +import gin +from internal import configs +from internal import datasets +from internal import models +from internal import train_utils +from internal import checkpoints +from internal import utils +from internal import vis +from matplotlib import cm +import mediapy as media +import torch +import numpy as np +import accelerate +import imageio +from torch.utils._pytree import tree_map + +configs.define_common_flags() + + +def create_videos(config, base_dir, out_dir, out_name, num_frames): + """Creates videos out of the images saved to disk.""" + names = [n for n in config.exp_path.split('/') if n] + # Last two parts of checkpoint path are experiment name and scene name. + exp_name, scene_name = names[-2:] + video_prefix = f'{scene_name}_{exp_name}_{out_name}' + + zpad = max(3, len(str(num_frames - 1))) + idx_to_str = lambda idx: str(idx).zfill(zpad) + + utils.makedirs(base_dir) + + # Load one example frame to get image shape and depth range. + depth_file = os.path.join(out_dir, f'distance_mean_{idx_to_str(0)}.tiff') + depth_frame = utils.load_img(depth_file) + shape = depth_frame.shape + p = config.render_dist_percentile + distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p]) + # lo, hi = [config.render_dist_curve_fn(x) for x in distance_limits] + depth_curve_fn = lambda x: -np.log(x + np.finfo(np.float32).eps) + lo, hi = distance_limits + print(f'Video shape is {shape[:2]}') + + for k in ['color', 'normals', 'acc', 'distance_mean', 'distance_median']: + video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4') + file_ext = 'png' if k in ['color', 'normals'] else 'tiff' + file0 = os.path.join(out_dir, f'{k}_{idx_to_str(0)}.{file_ext}') + if not utils.file_exists(file0): + print(f'Images missing for tag {k}') + continue + print(f'Making video {video_file}...') + + writer = imageio.get_writer(video_file, fps=config.render_video_fps) + for idx in range(num_frames): + img_file = os.path.join(out_dir, f'{k}_{idx_to_str(idx)}.{file_ext}') + if not utils.file_exists(img_file): + ValueError(f'Image file {img_file} does not exist.') + + img = utils.load_img(img_file) + if k in ['color', 'normals']: + img = img / 255. + elif k.startswith('distance'): + # img = config.render_dist_curve_fn(img) + # img = np.clip((img - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1) + # img = cm.get_cmap('turbo')(img)[..., :3] + + img = vis.visualize_cmap(img, np.ones_like(img), cm.get_cmap('turbo'), lo, hi, curve_fn=depth_curve_fn) + + frame = (np.clip(np.nan_to_num(img), 0., 1.) 
* 255.).astype(np.uint8) + writer.append_data(frame) + writer.close() + + +def main(unused_argv): + config = configs.load_config() + config.exp_path = os.path.join('exp', config.exp_name) + config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints') + config.render_dir = os.path.join(config.exp_path, 'render') + + accelerator = accelerate.Accelerator() + # setup logger + logging.basicConfig( + format="%(asctime)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + handlers=[logging.StreamHandler(sys.stdout), + logging.FileHandler(os.path.join(config.exp_path, 'log_render.txt'))], + level=logging.INFO, + ) + sys.excepthook = utils.handle_exception + logger = accelerate.logging.get_logger(__name__) + logger.info(config) + logger.info(accelerator.state, main_process_only=False) + + config.world_size = accelerator.num_processes + config.global_rank = accelerator.process_index + accelerate.utils.set_seed(config.seed, device_specific=True) + model = models.Model(config=config) + model.eval() + + dataset = datasets.load_dataset('test', config.data_dir, config) + dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)), + shuffle=False, + batch_size=1, + collate_fn=dataset.collate_fn, + ) + dataiter = iter(dataloader) + if config.rawnerf_mode: + postprocess_fn = dataset.metadata['postprocess_fn'] + else: + postprocess_fn = lambda z: z + + model = accelerator.prepare(model) + step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger) + + logger.info(f'Rendering checkpoint at step {step}.') + + out_name = 'path_renders' if config.render_path else 'test_preds' + out_name = f'{out_name}_step_{step}2' + out_dir = os.path.join(config.render_dir, out_name) + utils.makedirs(out_dir) + + path_fn = lambda x: os.path.join(out_dir, x) + + # Ensure sufficient zero-padding of image indices in output filenames. + zpad = max(3, len(str(dataset.size - 1))) + idx_to_str = lambda idx: str(idx).zfill(zpad) + + for idx in range(dataset.size): + # If current image and next image both already exist, skip ahead. + idx_str = idx_to_str(idx) + curr_file = path_fn(f'color_{idx_str}.png') + if utils.file_exists(curr_file): + logger.info(f'Image {idx + 1}/{dataset.size} already exists, skipping') + continue + batch = next(dataiter) + batch = tree_map(lambda x: x.to(accelerator.device) if x is not None else None, batch) + logger.info(f'Evaluating image {idx + 1}/{dataset.size}') + eval_start_time = time.time() + rendering = models.render_image(model, accelerator, + batch, False, 1, config) + + logger.info(f'Rendered in {(time.time() - eval_start_time):0.3f}s') + + if accelerator.is_main_process: # Only record via host 0. + rendering['rgb'] = postprocess_fn(rendering['rgb']) + rendering = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, rendering) + utils.save_img_u8(rendering['rgb'], path_fn(f'color_{idx_str}.png')) + if 'normals' in rendering: + utils.save_img_u8(rendering['normals'] / 2. 
+ 0.5, + path_fn(f'normals_{idx_str}.png')) + utils.save_img_f32(rendering['distance_mean'], + path_fn(f'distance_mean_{idx_str}.tiff')) + utils.save_img_f32(rendering['distance_median'], + path_fn(f'distance_median_{idx_str}.tiff')) + utils.save_img_f32(rendering['acc'], path_fn(f'acc_{idx_str}.tiff')) + num_files = len(glob.glob(path_fn('acc_*.tiff'))) + if accelerator.is_main_process and num_files == dataset.size: + logger.info(f'All files found, creating videos.') + create_videos(config, config.render_dir, out_dir, out_name, dataset.size) + accelerator.wait_for_everyone() + logger.info('Finish rendering.') + +if __name__ == '__main__': + with gin.config_scope('eval'): # Use the same scope as eval.py + app.run(main) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..bb929b6d893665cd83e49a77708d017cc16f27d8 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,24 @@ +numpy +torch +absl_py +accelerate +gin_config +imageio +imageio[ffmpeg] +matplotlib +mediapy +opencv_contrib_python +opencv_python +Pillow +trimesh +pymeshlab +xatlas +plyfile +rawpy +ninja +scipy +scikit-image +scikit-learn +tensorboard +tensorboardX +tqdm diff --git a/scripts/concat_videos.py b/scripts/concat_videos.py new file mode 100644 index 0000000000000000000000000000000000000000..091df223962b400bb453ea524c42f09007e5f6ca --- /dev/null +++ b/scripts/concat_videos.py @@ -0,0 +1,26 @@ +import os +import imageio +import cv2 +import numpy as np +from tqdm import tqdm + +MAX_H, MAX_W = 512, 768 +exp_name = "360_v2_glo" +keys = ["color", "distance_mean"] +# keys = ["color"] + +os.makedirs('assets', exist_ok=True) +root = os.path.join("exp", exp_name) +scenes = sorted(os.listdir(root)) + +video_files = [[os.path.join(root, scene, "render", + f"{scene}_{exp_name}_path_renders_step_25000_{k}.mp4") + for k in keys] for scene in scenes] +video_files = [f for f in video_files if os.path.exists(f[0])] + +with imageio.get_writer(os.path.join("assets", exp_name+'.mp4'), fps=30) as writer: + for scene_videos in tqdm(video_files): + readers = [imageio.get_reader(f) for f in scene_videos] + for data in zip(*readers): + data = np.concatenate([cv2.resize(img, (MAX_W, MAX_H)) for img in data], axis=1) + writer.append_data(data) diff --git a/scripts/eval_360.sh b/scripts/eval_360.sh new file mode 100644 index 0000000000000000000000000000000000000000..c34c36edaec81fddd09017cd598386830f1411d3 --- /dev/null +++ b/scripts/eval_360.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +SCENE=bicycle +EXPERIMENT=360_v2/"$SCENE" +DATA_ROOT=/SSD_DISK/datasets/360_v2 +DATA_DIR="$DATA_ROOT"/"$SCENE" + +accelerate launch eval.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" diff --git a/scripts/eval_360_all.sh b/scripts/eval_360_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..4bf4a6f4707cfa851e456cef548ad1f1a9f39760 --- /dev/null +++ b/scripts/eval_360_all.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +EXPERIMENT_PREFIX=360_v2 +SCENE=("bicycle" "garden" "stump") +DATA_ROOT=/SSD_DISK/datasets/360_v2 + +len=${#SCENE[@]} +for(( i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + + accelerate launch eval.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" +done + +SCENE=("room" 
"counter" "kitchen" "bonsai") +len=${#SCENE[@]} +for(( i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + + accelerate launch eval.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 2" +done \ No newline at end of file diff --git a/scripts/extract_360.sh b/scripts/extract_360.sh new file mode 100644 index 0000000000000000000000000000000000000000..6f9cf81239392611bda9787957e16d93bef24f6a --- /dev/null +++ b/scripts/extract_360.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +SCENE=bicycle +EXPERIMENT=blender/"$SCENE" +DATA_ROOT=/SSD_DISK/datasets/360_v2 +DATA_DIR="$DATA_ROOT"/"$SCENE" + +accelerate launch extract.py --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" diff --git a/scripts/extract_360_all.sh b/scripts/extract_360_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..919c874ab84883eb63d259624fd28a2f4a84c964 --- /dev/null +++ b/scripts/extract_360_all.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# outdoor +EXPERIMENT_PREFIX=360_v2_0527 +SCENE=("bicycle" "garden" "stump" ) +DATA_ROOT=/SSD_DISK/datasets/360_v2 + +len=${#SCENE[@]} +for(( i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + accelerate launch extract.py --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" +done + +# indoor "Config.factor = 2" +SCENE=("room" "counter" "kitchen" "bonsai") +len=${#SCENE[@]} +for((i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + + accelerate launch extract.py --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 2" +done \ No newline at end of file diff --git a/scripts/extract_360_tsdf.sh b/scripts/extract_360_tsdf.sh new file mode 100644 index 0000000000000000000000000000000000000000..e4d02e22da0f7be9f188e33cf6321682790de095 --- /dev/null +++ b/scripts/extract_360_tsdf.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +SCENE=bicycle +EXPERIMENT=blender/"$SCENE" +DATA_ROOT=/SSD_DISK/datasets/360_v2 +DATA_DIR="$DATA_ROOT"/"$SCENE" + +accelerate launch tsdf.py --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" diff --git a/scripts/extract_360_tsdf_all.sh b/scripts/extract_360_tsdf_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..2383826506dfaadea0b10b7d6a096e9863f2cd4a --- /dev/null +++ b/scripts/extract_360_tsdf_all.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +# outdoor +EXPERIMENT_PREFIX=360_v2_0527 +SCENE=("bicycle" "garden" "stump" ) +DATA_ROOT=/SSD_DISK/datasets/360_v2 + +len=${#SCENE[@]} +for(( i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + accelerate launch tsdf.py --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" +done + +# indoor "Config.factor = 2" +SCENE=("room" "counter" "kitchen" "bonsai") 
+len=${#SCENE[@]} +for((i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + + accelerate launch tsdf.py --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 2" +done \ No newline at end of file diff --git a/scripts/extract_blender_all.sh b/scripts/extract_blender_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..d6093771a67bd960b3146eed9d70c30aa281b0a6 --- /dev/null +++ b/scripts/extract_blender_all.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +EXPERIMENT_PREFIX=blender +SCENE=("drums" "ficus" "hotdog" "lego" "materials" "mic" "ship") +DATA_ROOT=/SSD_DISK/datasets/nerf_synthetic + +len=${#SCENE[@]} +for(( i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + accelerate launch extract.py --gin_configs=configs/blender.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" +done diff --git a/scripts/local_colmap_and_resize.sh b/scripts/local_colmap_and_resize.sh new file mode 100644 index 0000000000000000000000000000000000000000..5e6d718e50b2eb5b6c2347a11cbb10264ad477fc --- /dev/null +++ b/scripts/local_colmap_and_resize.sh @@ -0,0 +1,80 @@ +#!/bin/bash + + +# Set to 0 if you do not have a GPU. +USE_GPU=0 +# Path to a directory `base/` with images in `base/images/`. +DATASET_PATH=$1 +# Recommended CAMERA values: OPENCV for perspective, OPENCV_FISHEYE for fisheye. +CAMERA=${2:-OPENCV} + + +# Run COLMAP. + +### Feature extraction + +colmap feature_extractor \ + --database_path "$DATASET_PATH"/database.db \ + --image_path "$DATASET_PATH"/images \ + --ImageReader.single_camera 1 \ + --ImageReader.camera_model "$CAMERA" \ + --SiftExtraction.use_gpu "$USE_GPU" + + +### Feature matching + +colmap exhaustive_matcher \ + --database_path "$DATASET_PATH"/database.db \ + --SiftMatching.use_gpu "$USE_GPU" + +## Use if your scene has > 500 images +## Replace this path with your own local copy of the file. +## Download from: https://demuc.de/colmap/#download +# VOCABTREE_PATH=/usr/local/google/home/bmild/vocab_tree_flickr100K_words32K.bin +# colmap vocab_tree_matcher \ +# --database_path "$DATASET_PATH"/database.db \ +# --VocabTreeMatching.vocab_tree_path $VOCABTREE_PATH \ +# --SiftMatching.use_gpu "$USE_GPU" + + +### Bundle adjustment + +# The default Mapper tolerance is unnecessarily large, +# decreasing it speeds up bundle adjustment steps. +mkdir -p "$DATASET_PATH"/sparse +colmap mapper \ + --database_path "$DATASET_PATH"/database.db \ + --image_path "$DATASET_PATH"/images \ + --output_path "$DATASET_PATH"/sparse \ + --Mapper.ba_global_function_tolerance=0.000001 + + +### Image undistortion + +## Use this if you want to undistort your images into ideal pinhole intrinsics. +# mkdir -p "$DATASET_PATH"/dense +# colmap image_undistorter \ +# --image_path "$DATASET_PATH"/images \ +# --input_path "$DATASET_PATH"/sparse/0 \ +# --output_path "$DATASET_PATH"/dense \ +# --output_type COLMAP + +# Resize images. 
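+# Copies of images/ are written to images_2, images_4 and images_8, downsampled to 50%, 25% and 12.5% with ImageMagick's mogrify.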
+ +cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_2 + +pushd "$DATASET_PATH"/images_2 +ls | xargs -P 8 -I {} mogrify -resize 50% {} +popd + +cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_4 + +pushd "$DATASET_PATH"/images_4 +ls | xargs -P 8 -I {} mogrify -resize 25% {} +popd + +cp -r "$DATASET_PATH"/images "$DATASET_PATH"/images_8 + +pushd "$DATASET_PATH"/images_8 +ls | xargs -P 8 -I {} mogrify -resize 12.5% {} +popd \ No newline at end of file diff --git a/scripts/render_360.sh b/scripts/render_360.sh new file mode 100644 index 0000000000000000000000000000000000000000..0127f0e0d0de010619ed521baa965c6fff714be9 --- /dev/null +++ b/scripts/render_360.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +SCENE=bicycle +EXPERIMENT=360_v2/"$SCENE" +DATA_ROOT=/SSD_DISK/datasets/360_v2 +DATA_DIR="$DATA_ROOT"/"$SCENE" + +accelerate launch render.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.render_path = True" \ + --gin_bindings="Config.render_path_frames = 120" \ + --gin_bindings="Config.render_video_fps = 30" \ + --gin_bindings="Config.factor = 4" diff --git a/scripts/render_360_all.sh b/scripts/render_360_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..60d2a1da6b13751cec254f0aa8662e7aa930f5c7 --- /dev/null +++ b/scripts/render_360_all.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +EXPERIMENT_PREFIX=360_v2 +SCENE=("bicycle" "garden" "stump") +DATA_ROOT=/SSD_DISK/datasets/360_v2 + +len=${#SCENE[@]} +for(( i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + accelerate launch render.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.render_path = True" \ + --gin_bindings="Config.render_path_frames = 120" \ + --gin_bindings="Config.render_video_fps = 30" \ + --gin_bindings="Config.factor = 4" +done + + +SCENE=("room" "counter" "kitchen" "bonsai") +len=${#SCENE[@]} +for(( i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + accelerate launch render.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.render_path = True" \ + --gin_bindings="Config.render_path_frames = 120" \ + --gin_bindings="Config.render_video_fps = 30" \ + --gin_bindings="Config.factor = 2" +done diff --git a/scripts/train_360.sh b/scripts/train_360.sh new file mode 100644 index 0000000000000000000000000000000000000000..81270d51e9e76b470b1f641a1f77f8c97c9a9064 --- /dev/null +++ b/scripts/train_360.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +SCENE=bicycle +EXPERIMENT=360_v2/"$SCENE" +DATA_ROOT=/SSD_DISK/datasets/360_v2 +DATA_DIR="$DATA_ROOT"/"$SCENE" + +rm exp/"$EXPERIMENT"/* +accelerate launch train.py --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" diff --git a/scripts/train_360_all.sh b/scripts/train_360_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..53921bbbd036d68a473b6fb5735925612a5d94cb --- /dev/null +++ b/scripts/train_360_all.sh @@ -0,0 +1,81 @@ +#!/bin/bash + +# outdoor +EXPERIMENT_PREFIX=360_v2_gradient_scaling +SCENE=("bicycle" "garden" "stump" ) +#SCENE=("garden" ) 
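+# each outdoor scene below is trained, evaluated, rendered, and mesh-extracted in sequence at factor 4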
+#DATA_ROOT=/SSD_DISK/datasets/360_v2 +DATA_ROOT=/SSD_DISK/users/guchun/speed/zipnerf-pytorch/data/360_v2 + +len=${#SCENE[@]} +for((i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + + accelerate launch train.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.gradient_scaling = True" \ + --gin_bindings="Config.factor = 4" + + accelerate launch eval.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" + + accelerate launch render.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.render_path = True" \ + --gin_bindings="Config.render_path_frames = 120" \ + --gin_bindings="Config.render_video_fps = 30" \ + --gin_bindings="Config.factor = 4" + + accelerate launch extract.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" +done + +# indoor "Config.factor = 2" +SCENE=("room" "counter" "kitchen" "bonsai") +SCENE=() +len=${#SCENE[@]} +for((i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + + accelerate launch train.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.gradient_scaling = True" \ + --gin_bindings="Config.factor = 2" + + accelerate launch eval.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 2" + + accelerate launch render.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.render_path = True" \ + --gin_bindings="Config.render_path_frames = 120" \ + --gin_bindings="Config.render_video_fps = 30" \ + --gin_bindings="Config.factor = 2" + + accelerate launch extract.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 2" +done \ No newline at end of file diff --git a/scripts/train_360_glo.sh b/scripts/train_360_glo.sh new file mode 100644 index 0000000000000000000000000000000000000000..cfdbc57dc2036c0c223e4991f0444ff74d524d5f --- /dev/null +++ b/scripts/train_360_glo.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +SCENE=bicycle +EXPERIMENT=360_v2_glo/"$SCENE" +DATA_ROOT=/SSD_DISK/datasets/360_v2 +DATA_DIR="$DATA_ROOT"/"$SCENE" + +rm exp/"$EXPERIMENT"/* +python train.py --gin_configs=configs/360_glo.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" + diff --git a/scripts/train_360_glo_all.sh b/scripts/train_360_glo_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..122f9e6ce08500980ba55e7acde45f44665685c2 --- /dev/null +++ b/scripts/train_360_glo_all.sh @@ -0,0 +1,78 @@ +#!/bin/bash + +# outdoor +EXPERIMENT_PREFIX=360_v2_glo +SCENE=("bicycle" "garden" "stump" ) 
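+# GLO variant: every stage uses configs/360_glo.gin, and mesh extraction additionally enables vertex_projection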
+DATA_ROOT=/SSD_DISK/datasets/360_v2 + +len=${#SCENE[@]} +for((i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + + accelerate launch train.py \ + --gin_configs=configs/360_glo.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" + + accelerate launch eval.py \ + --gin_configs=configs/360_glo.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" + + accelerate launch render.py \ + --gin_configs=configs/360_glo.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.render_path = True" \ + --gin_bindings="Config.render_path_frames = 120" \ + --gin_bindings="Config.render_video_fps = 30" \ + --gin_bindings="Config.factor = 4" + + accelerate launch extract.py \ + --gin_configs=configs/360_glo.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" \ + --gin_bindings="Config.vertex_projection = True" +done + +# indoor "Config.factor = 2" +SCENE=("room" "counter" "kitchen" "bonsai") +len=${#SCENE[@]} +for((i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + + accelerate launch train.py \ + --gin_configs=configs/360_glo.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 2" + + accelerate launch eval.py \ + --gin_configs=configs/360_glo.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 2" + + accelerate launch render.py \ + --gin_configs=configs/360_glo.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.render_path = True" \ + --gin_bindings="Config.render_path_frames = 120" \ + --gin_bindings="Config.render_video_fps = 30" \ + --gin_bindings="Config.factor = 2" + + accelerate launch extract.py \ + --gin_configs=configs/360_glo.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 2" \ + --gin_bindings="Config.vertex_projection = True" +done \ No newline at end of file diff --git a/scripts/train_blender.sh b/scripts/train_blender.sh new file mode 100644 index 0000000000000000000000000000000000000000..a3e7a7e19f17e2b8b6fac46f8a89cebfde5a5a2e --- /dev/null +++ b/scripts/train_blender.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +SCENE=chair +EXPERIMENT=blender/"$SCENE" +DATA_ROOT=/SSD_DISK/datasets/nerf_synthetic +DATA_DIR="$DATA_ROOT"/"$SCENE" + +rm exp/"$EXPERIMENT"/* +accelerate launch train.py --gin_configs=configs/blender.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" diff --git a/scripts/train_blender_all.sh b/scripts/train_blender_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..2ccad7af332c20a1987e696829cdb2ba666cdfcf --- /dev/null +++ b/scripts/train_blender_all.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +EXPERIMENT_PREFIX=blender +SCENE=("drums" "ficus" "hotdog" "lego" "materials" "mic" "ship") +DATA_ROOT=/SSD_DISK/datasets/nerf_synthetic + +len=${#SCENE[@]} 
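+# each blender scene is trained, evaluated, and then mesh-extracted at 1024^3 voxels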
+for((i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + + accelerate launch train.py \ + --gin_configs=configs/blender.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" + + accelerate launch eval.py \ + --gin_configs=configs/blender.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" + + accelerate launch extract.py \ + --gin_configs=configs/blender.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.mesh_voxels = 1073741824" # 1024 ** 3 +done \ No newline at end of file diff --git a/scripts/train_multi360.sh b/scripts/train_multi360.sh new file mode 100644 index 0000000000000000000000000000000000000000..9d0a0966195a6fcb8751578bc80428e50c135ca2 --- /dev/null +++ b/scripts/train_multi360.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +SCENE=bicycle +EXPERIMENT=360_v2_multiscale/"$SCENE" +DATA_ROOT=/SSD_DISK/datasets/360_v2 +DATA_DIR="$DATA_ROOT"/"$SCENE" + +rm exp/"$EXPERIMENT"/* +accelerate launch train.py --gin_configs=configs/multi360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" diff --git a/scripts/train_multi360_all.sh b/scripts/train_multi360_all.sh new file mode 100644 index 0000000000000000000000000000000000000000..07554e1804d3a2c0b91a24792b08e398efb16b67 --- /dev/null +++ b/scripts/train_multi360_all.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# outdoor +EXPERIMENT_PREFIX=360_v2_multiscale +SCENE=("bicycle" "garden" "stump" ) +DATA_ROOT=/SSD_DISK/datasets/360_v2 + +len=${#SCENE[@]} +for((i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + + accelerate launch train.py \ + --gin_configs=configs/multi360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" + + accelerate launch eval.py \ + --gin_configs=configs/multi360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 4" + + accelerate launch render.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.render_path = True" \ + --gin_bindings="Config.render_path_frames = 120" \ + --gin_bindings="Config.render_video_fps = 30" \ + --gin_bindings="Config.factor = 4" +done + +# indoor "Config.factor = 2" +SCENE=("room" "counter" "kitchen" "bonsai") +len=${#SCENE[@]} +for((i=0; i<$len; i++ )) +do + EXPERIMENT=$EXPERIMENT_PREFIX/"${SCENE[i]}" + DATA_DIR="$DATA_ROOT"/"${SCENE[i]}" + + accelerate launch train.py \ + --gin_configs=configs/multi360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 2" + + accelerate launch eval.py \ + --gin_configs=configs/multi360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.factor = 2" + + accelerate launch render.py \ + --gin_configs=configs/360.gin \ + --gin_bindings="Config.data_dir = '${DATA_DIR}'" \ + --gin_bindings="Config.exp_name = '${EXPERIMENT}'" \ + --gin_bindings="Config.render_path = True" \ + 
--gin_bindings="Config.render_path_frames = 120" \ + --gin_bindings="Config.render_video_fps = 30" \ + --gin_bindings="Config.factor = 2" +done \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..349b8464ece59f99789c4d0313129394b3d02791 --- /dev/null +++ b/train.py @@ -0,0 +1,387 @@ +import glob +import logging +import os +import shutil +import sys + +import numpy as np +import random + +import time + +from absl import app +import gin +from internal import configs +from internal import datasets +from internal import image +from internal import models +from internal import train_utils +from internal import utils +from internal import vis +from internal import checkpoints +import torch +import accelerate +import tensorboardX +from tqdm import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm +from torch.utils._pytree import tree_map + +configs.define_common_flags() + +TIME_PRECISION = 1000 # Internally represent integer times in milliseconds. + + + +def main(unused_argv): + config = configs.load_config() + config.exp_path = os.path.join("exp", config.exp_name) + config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints') + utils.makedirs(config.exp_path) + with utils.open_file(os.path.join(config.exp_path, 'config.gin'), 'w') as f: + f.write(gin.config_str()) + + # accelerator for DDP + accelerator = accelerate.Accelerator() + + # setup logger + logging.basicConfig( + format="%(asctime)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + handlers=[logging.StreamHandler(sys.stdout), + logging.FileHandler(os.path.join(config.exp_path, 'log_train.txt'))], + level=logging.INFO, + ) + sys.excepthook = utils.handle_exception + logger = accelerate.logging.get_logger(__name__) + logger.info(config) + logger.info(accelerator.state, main_process_only=False) + + config.world_size = accelerator.num_processes + config.global_rank = accelerator.process_index + if config.batch_size % accelerator.num_processes != 0: + config.batch_size -= config.batch_size % accelerator.num_processes != 0 + logger.info('turn batch size to', config.batch_size) + + # Set random seed. + accelerate.utils.set_seed(config.seed, device_specific=True) + # setup model and optimizer + model = models.Model(config=config) + optimizer, lr_fn = train_utils.create_optimizer(config, model) + + # load dataset + dataset = datasets.load_dataset('train', config.data_dir, config) + test_dataset = datasets.load_dataset('test', config.data_dir, config) + dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)), + num_workers=8, + shuffle=True, + batch_size=1, + collate_fn=dataset.collate_fn, + persistent_workers=True, + ) + test_dataloader = torch.utils.data.DataLoader(np.arange(len(test_dataset)), + num_workers=4, + shuffle=False, + batch_size=1, + persistent_workers=True, + collate_fn=test_dataset.collate_fn, + ) + if config.rawnerf_mode: + postprocess_fn = test_dataset.metadata['postprocess_fn'] + else: + postprocess_fn = lambda z, _=None: z + + # use accelerate to prepare. 
+ model, dataloader, optimizer = accelerator.prepare(model, dataloader, optimizer) + + if config.resume_from_checkpoint: + init_step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger) + else: + init_step = 0 + + module = accelerator.unwrap_model(model) + dataiter = iter(dataloader) + test_dataiter = iter(test_dataloader) + + num_params = train_utils.tree_len(list(model.parameters())) + logger.info(f'Number of parameters being optimized: {num_params}') + + if (dataset.size > module.num_glo_embeddings and module.num_glo_features > 0): + raise ValueError(f'Number of glo embeddings {module.num_glo_embeddings} ' + f'must be at least equal to number of train images ' + f'{dataset.size}') + + # metric handler + metric_harness = image.MetricHarness() + + # tensorboard + if accelerator.is_main_process: + summary_writer = tensorboardX.SummaryWriter(config.exp_path) + # function to convert image for tensorboard + tb_process_fn = lambda x: x.transpose(2, 0, 1) if len(x.shape) == 3 else x[None] + + if config.rawnerf_mode: + for name, data in zip(['train', 'test'], [dataset, test_dataset]): + # Log shutter speed metadata in TensorBoard for debug purposes. + for key in ['exposure_idx', 'exposure_values', 'unique_shutters']: + summary_writer.add_text(f'{name}_{key}', str(data.metadata[key]), 0) + logger.info("Begin training...") + step = init_step + 1 + total_time = 0 + total_steps = 0 + reset_stats = True + if config.early_exit_steps is not None: + num_steps = config.early_exit_steps + else: + num_steps = config.max_steps + init_step = 0 + with logging_redirect_tqdm(): + tbar = tqdm(range(init_step + 1, num_steps + 1), + desc='Training', initial=init_step, total=num_steps, + disable=not accelerator.is_main_process) + for step in tbar: + try: + batch = next(dataiter) + except StopIteration: + dataiter = iter(dataloader) + batch = next(dataiter) + batch = accelerate.utils.send_to_device(batch, accelerator.device) + if reset_stats and accelerator.is_main_process: + stats_buffer = [] + train_start_time = time.time() + reset_stats = False + + # use lr_fn to control learning rate + learning_rate = lr_fn(step) + for param_group in optimizer.param_groups: + param_group['lr'] = learning_rate + + # fraction of training period + train_frac = np.clip((step - 1) / (config.max_steps - 1), 0, 1) + + # Indicates whether we need to compute output normal or depth maps in 2D. 
+ compute_extras = (config.compute_disp_metrics or config.compute_normal_metrics) + optimizer.zero_grad() + with accelerator.autocast(): + renderings, ray_history = model( + True, + batch, + train_frac=train_frac, + compute_extras=compute_extras, + zero_glo=False) + + losses = {} + + # supervised by data + data_loss, stats = train_utils.compute_data_loss(batch, renderings, config) + losses['data'] = data_loss + + # interlevel loss in MipNeRF360 + if config.interlevel_loss_mult > 0 and not module.single_mlp: + losses['interlevel'] = train_utils.interlevel_loss(ray_history, config) + + # interlevel loss in ZipNeRF360 + if config.anti_interlevel_loss_mult > 0 and not module.single_mlp: + losses['anti_interlevel'] = train_utils.anti_interlevel_loss(ray_history, config) + + # distortion loss + if config.distortion_loss_mult > 0: + losses['distortion'] = train_utils.distortion_loss(ray_history, config) + + # opacity loss + if config.opacity_loss_mult > 0: + losses['opacity'] = train_utils.opacity_loss(renderings, config) + + # orientation loss in RefNeRF + if (config.orientation_coarse_loss_mult > 0 or + config.orientation_loss_mult > 0): + losses['orientation'] = train_utils.orientation_loss(batch, module, ray_history, + config) + # hash grid l2 weight decay + if config.hash_decay_mults > 0: + losses['hash_decay'] = train_utils.hash_decay_loss(ray_history, config) + + # normal supervision loss in RefNeRF + if (config.predicted_normal_coarse_loss_mult > 0 or + config.predicted_normal_loss_mult > 0): + losses['predicted_normals'] = train_utils.predicted_normal_loss( + module, ray_history, config) + loss = sum(losses.values()) + stats['loss'] = loss.item() + stats['losses'] = tree_map(lambda x: x.item(), losses) + + # accelerator automatically handle the scale + accelerator.backward(loss) + # clip gradient by max/norm/nan + train_utils.clip_gradients(model, accelerator, config) + optimizer.step() + + stats['psnrs'] = image.mse_to_psnr(stats['mses']) + stats['psnr'] = stats['psnrs'][-1] + + # Log training summaries. This is put behind a host_id check because in + # multi-host evaluation, all hosts need to run inference even though we + # only use host 0 to record results. + if accelerator.is_main_process: + stats_buffer.append(stats) + if step == init_step + 1 or step % config.print_every == 0: + elapsed_time = time.time() - train_start_time + steps_per_sec = config.print_every / elapsed_time + rays_per_sec = config.batch_size * steps_per_sec + + # A robust approximation of total training time, in case of pre-emption. + total_time += int(round(TIME_PRECISION * elapsed_time)) + total_steps += config.print_every + approx_total_time = int(round(step * total_time / total_steps)) + + # Transpose and stack stats_buffer along axis 0. + fs = [utils.flatten_dict(s, sep='/') for s in stats_buffer] + stats_stacked = {k: np.stack([f[k] for f in fs]) for k in fs[0].keys()} + + # Split every statistic that isn't a vector into a set of statistics. + stats_split = {} + for k, v in stats_stacked.items(): + if v.ndim not in [1, 2] and v.shape[0] != len(stats_buffer): + raise ValueError('statistics must be of size [n], or [n, k].') + if v.ndim == 1: + stats_split[k] = v + elif v.ndim == 2: + for i, vi in enumerate(tuple(v.T)): + stats_split[f'{k}/{i}'] = vi + + # Summarize the entire histogram of each statistic. + for k, v in stats_split.items(): + summary_writer.add_histogram('train_' + k, v, step) + + # Take the mean and max of each statistic since the last summary. 
+ avg_stats = {k: np.mean(v) for k, v in stats_split.items()} + max_stats = {k: np.max(v) for k, v in stats_split.items()} + + summ_fn = lambda s, v: summary_writer.add_scalar(s, v, step) # pylint:disable=cell-var-from-loop + + # Summarize the mean and max of each statistic. + for k, v in avg_stats.items(): + summ_fn(f'train_avg_{k}', v) + for k, v in max_stats.items(): + summ_fn(f'train_max_{k}', v) + + summ_fn('train_num_params', num_params) + summ_fn('train_learning_rate', learning_rate) + summ_fn('train_steps_per_sec', steps_per_sec) + summ_fn('train_rays_per_sec', rays_per_sec) + + summary_writer.add_scalar('train_avg_psnr_timed', avg_stats['psnr'], + total_time // TIME_PRECISION) + summary_writer.add_scalar('train_avg_psnr_timed_approx', avg_stats['psnr'], + approx_total_time // TIME_PRECISION) + + if dataset.metadata is not None and module.learned_exposure_scaling: + scalings = module.exposure_scaling_offsets.weight + num_shutter_speeds = dataset.metadata['unique_shutters'].shape[0] + for i_s in range(num_shutter_speeds): + for j_s, value in enumerate(scalings[i_s]): + summary_name = f'exposure/scaling_{i_s}_{j_s}' + summary_writer.add_scalar(summary_name, value, step) + + precision = int(np.ceil(np.log10(config.max_steps))) + 1 + avg_loss = avg_stats['loss'] + avg_psnr = avg_stats['psnr'] + str_losses = { # Grab each "losses_{x}" field and print it as "x[:4]". + k[7:11]: (f'{v:0.5f}' if 1e-4 <= v < 10 else f'{v:0.1e}') + for k, v in avg_stats.items() + if k.startswith('losses/') + } + logger.info(f'{step}' + f'/{config.max_steps:d}:' + + f'loss={avg_loss:0.5f},' + f'psnr={avg_psnr:.3f},' + + f'lr={learning_rate:0.2e} | ' + + ','.join([f'{k}={s}' for k, s in str_losses.items()]) + + f',{rays_per_sec:0.0f} r/s') + + # Reset everything we are tracking between summarizations. + reset_stats = True + + if step > 0 and step % config.checkpoint_every == 0 and accelerator.is_main_process: + checkpoints.save_checkpoint(config.checkpoint_dir, + accelerator, step, + config.checkpoints_total_limit) + + # Test-set evaluation. + if config.train_render_every > 0 and step % config.train_render_every == 0: + # We reuse the same random number generator from the optimization step + # here on purpose so that the visualization matches what happened in + # training. + eval_start_time = time.time() + try: + test_batch = next(test_dataiter) + except StopIteration: + test_dataiter = iter(test_dataloader) + test_batch = next(test_dataiter) + test_batch = accelerate.utils.send_to_device(test_batch, accelerator.device) + + # render a single image with all distributed processes + rendering = models.render_image(model, accelerator, + test_batch, False, + train_frac, config) + + # move to numpy + rendering = tree_map(lambda x: x.detach().cpu().numpy(), rendering) + test_batch = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, test_batch) + # Log eval summaries on host 0. 
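+                # all ranks took part in rendering above; only the main process computes metrics and writes TensorBoard summaries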
+ if accelerator.is_main_process: + eval_time = time.time() - eval_start_time + num_rays = np.prod(test_batch['directions'].shape[:-1]) + rays_per_sec = num_rays / eval_time + summary_writer.add_scalar('test_rays_per_sec', rays_per_sec, step) + + metric_start_time = time.time() + metric = metric_harness( + postprocess_fn(rendering['rgb']), postprocess_fn(test_batch['rgb'])) + logger.info(f'Eval {step}: {eval_time:0.3f}s, {rays_per_sec:0.0f} rays/sec') + logger.info(f'Metrics computed in {(time.time() - metric_start_time):0.3f}s') + for name, val in metric.items(): + if not np.isnan(val): + logger.info(f'{name} = {val:.4f}') + summary_writer.add_scalar('train_metrics/' + name, val, step) + + if config.vis_decimate > 1: + d = config.vis_decimate + decimate_fn = lambda x, d=d: None if x is None else x[::d, ::d] + else: + decimate_fn = lambda x: x + rendering = tree_map(decimate_fn, rendering) + test_batch = tree_map(decimate_fn, test_batch) + vis_start_time = time.time() + vis_suite = vis.visualize_suite(rendering, test_batch) + with tqdm.external_write_mode(): + logger.info(f'Visualized in {(time.time() - vis_start_time):0.3f}s') + if config.rawnerf_mode: + # Unprocess raw output. + vis_suite['color_raw'] = rendering['rgb'] + # Autoexposed colors. + vis_suite['color_auto'] = postprocess_fn(rendering['rgb'], None) + summary_writer.add_image('test_true_auto', + tb_process_fn(postprocess_fn(test_batch['rgb'], None)), step) + # Exposure sweep colors. + exposures = test_dataset.metadata['exposure_levels'] + for p, x in list(exposures.items()): + vis_suite[f'color/{p}'] = postprocess_fn(rendering['rgb'], x) + summary_writer.add_image(f'test_true_color/{p}', + tb_process_fn(postprocess_fn(test_batch['rgb'], x)), step) + summary_writer.add_image('test_true_color', tb_process_fn(test_batch['rgb']), step) + if config.compute_normal_metrics: + summary_writer.add_image('test_true_normals', + tb_process_fn(test_batch['normals']) / 2. 
+ 0.5, step) + for k, v in vis_suite.items(): + summary_writer.add_image('test_output_' + k, tb_process_fn(v), step) + + if accelerator.is_main_process and config.max_steps > init_step: + logger.info('Saving last checkpoint at step {} to {}'.format(step, config.checkpoint_dir)) + checkpoints.save_checkpoint(config.checkpoint_dir, + accelerator, step, + config.checkpoints_total_limit) + logger.info('Finish training.') + + +if __name__ == '__main__': + with gin.config_scope('train'): + app.run(main) diff --git a/tsdf.py b/tsdf.py new file mode 100644 index 0000000000000000000000000000000000000000..8343eb1c567e307d8b7a4a38b39ee60e9b48c474 --- /dev/null +++ b/tsdf.py @@ -0,0 +1,350 @@ +import glob +import logging +import os +import sys +import time + +import cv2 +import numpy as np +from absl import app +import gin +from internal import configs +from internal import datasets +from internal import models +from internal import utils +from internal import coord +from internal import checkpoints +from internal import configs +import torch +import accelerate +from tqdm import tqdm +from torch.utils._pytree import tree_map +import torch.nn.functional as F +from skimage import measure +import trimesh +import pymeshlab as pml +from torch import Tensor + +configs.define_common_flags() + + +class TSDF: + def __init__(self, config: configs.Config, accelerator: accelerate.Accelerator): + self.config = config + self.device = accelerator.device + self.accelerator = accelerator + self.origin = torch.tensor([-config.tsdf_radius] * 3, dtype=torch.float32, device=self.device) + self.voxel_size = 2 * config.tsdf_radius / (config.tsdf_resolution - 1) + self.resolution = config.tsdf_resolution + # create the voxel coordinates + dim = torch.arange(self.resolution) + grid = torch.stack(torch.meshgrid(dim, dim, dim, indexing="ij"), dim=0).reshape(3, -1) + period = int(grid.shape[1] / accelerator.num_processes + 0.5) + grid = grid[:, period * accelerator.process_index: period * (accelerator.process_index + 1)] + self.voxel_coords = self.origin.view(3, 1) + grid.to(self.device) * self.voxel_size + + N = self.voxel_coords.shape[1] + # make voxel_coords homogeneous + voxel_world_coords = coord.inv_contract(self.voxel_coords.permute(1, 0)).permute(1, 0).view(3, -1) + # voxel_world_coords = self.voxel_coords.view(3, -1) + voxel_world_coords = torch.cat( + [voxel_world_coords, torch.ones(1, voxel_world_coords.shape[1], device=self.device)], dim=0 + ) + voxel_world_coords = voxel_world_coords.unsqueeze(0) # [1, 4, N] + self.voxel_world_coords = voxel_world_coords.expand(-1, *voxel_world_coords.shape[1:]) # [1, 4, N] + + # initialize the values and weights + self.values = torch.ones(N, dtype=torch.float32, + device=self.device) + self.weights = torch.zeros(N, dtype=torch.float32, + device=self.device) + self.colors = torch.zeros(N, 3, dtype=torch.float32, + device=self.device) + + @property + def truncation(self): + """Returns the truncation distance.""" + # TODO: clean this up + truncation = self.voxel_size * self.config.truncation_margin + return truncation + + def export_mesh(self, path): + """Extracts a mesh using marching cubes.""" + # run marching cubes on CPU + tsdf_values = self.values.clamp(-1, 1) + mask = self.voxel_world_coords[:, :3].permute(0, 2, 1).norm(p=2, dim=-1) > self.config.tsdf_max_radius + tsdf_values[mask.reshape(self.values.shape)] = 1. 
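+        # each process holds only its slice of the TSDF grid (split in __init__); gather() reassembles the full resolution^3 volume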
+ + tsdf_values_np = self.accelerator.gather(tsdf_values).cpu().reshape((self.resolution, self.resolution, self.resolution)).numpy() + color_values_np = self.accelerator.gather(self.colors).cpu().reshape((self.resolution, self.resolution, self.resolution, 3)).numpy() + + # # for OOM(resolution > 512) + # tsdf_values_np = tsdf_values.cpu().numpy() + # color_values_np = self.colors.cpu().numpy() + # path_dir = os.path.dirname(path) + # np.save(os.path.join(path_dir, 'tsdf_values_tmp_{}.npy'.format(self.accelerator.process_index)), tsdf_values_np) + # np.save(os.path.join(path_dir, 'color_values_tmp_{}.npy'.format(self.accelerator.process_index)), color_values_np) + # self.accelerator.wait_for_everyone() + + if self.accelerator.is_main_process: + # print('Start marching cubes') + # tsdf_values_np = np.concatenate([np.load(os.path.join(path_dir, 'tsdf_values_tmp_{}.npy'.format(i)), allow_pickle=True) for i in + # range(self.accelerator.num_processes)]).reshape((self.resolution, self.resolution, self.resolution)) + # color_values_np = np.concatenate([np.load(os.path.join(path_dir, 'color_values_tmp_{}.npy'.format(i)), allow_pickle=True) for i in + # range(self.accelerator.num_processes)]).reshape((self.resolution, self.resolution, self.resolution, 3)) + # print('After concatenate') + # os.system('rm {}'.format(os.path.join(path_dir, 'tsdf_values_tmp_*.npy'))) + # os.system('rm {}'.format(os.path.join(path_dir, 'color_values_tmp_*.npy'))) + vertices, faces, normals, _ = measure.marching_cubes( + tsdf_values_np, + level=0, + allow_degenerate=False, + ) + + vertices_indices = np.round(vertices).astype(int) + colors = color_values_np[vertices_indices[:, 0], vertices_indices[:, 1], vertices_indices[:, 2]] + + # move vertices back to world space + vertices = self.origin.cpu().numpy() + vertices * self.voxel_size + vertices = coord.inv_contract_np(vertices) + trimesh.Trimesh(vertices=vertices, + faces=faces, + normals=normals, + vertex_colors=colors, + ).export(path) + + @torch.no_grad() + def integrate_tsdf( + self, + c2w, + K, + depth_images, + color_images=None, + ): + """Integrates a batch of depth images into the TSDF. + + Args: + c2w: The camera extrinsics. + K: The camera intrinsics. + depth_images: The depth images to integrate. + color_images: The color images to integrate. + """ + batch_size = c2w.shape[0] + shape = self.voxel_coords.shape[1:] + + # Project voxel_coords into image space... + image_size = torch.tensor( + [depth_images.shape[-1], depth_images.shape[-2]], device=self.device + ) # [width, height] + + # make voxel_coords homogeneous + voxel_world_coords = self.voxel_world_coords.expand(batch_size, + *self.voxel_world_coords.shape[1:]) # [batch, 4, N] + + voxel_cam_coords = torch.bmm(torch.inverse(c2w), voxel_world_coords) # [batch, 4, N] + + # flip the z axis + voxel_cam_coords[:, 2, :] = -voxel_cam_coords[:, 2, :] + # flip the y axis + voxel_cam_coords[:, 1, :] = -voxel_cam_coords[:, 1, :] + + # # we need the distance of the point to the camera, not the z coordinate + # # TODO: why is this not the z coordinate? + # voxel_depth = torch.sqrt(torch.sum(voxel_cam_coords[:, :3, :] ** 2, dim=-2, keepdim=True)) # [batch, 1, N] + + voxel_cam_coords_z = voxel_cam_coords[:, 2:3, :] + voxel_depth = voxel_cam_coords_z + + voxel_cam_points = torch.bmm(K[None].expand(batch_size, -1, -1), + voxel_cam_coords[:, 0:3, :] / voxel_cam_coords_z) # [batch, 3, N] + voxel_pixel_coords = voxel_cam_points[:, :2, :] # [batch, 2, N] + + # Sample the depth images with grid sample... 
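+        # F.grid_sample expects sampling locations normalized to [-1, 1], hence the 2 * grid / image_size - 1 mapping below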
+ + grid = voxel_pixel_coords.permute(0, 2, 1) # [batch, N, 2] + # normalize grid to [-1, 1] + grid = 2.0 * grid / image_size.view(1, 1, 2) - 1.0 # [batch, N, 2] + grid = grid[:, None] # [batch, 1, N, 2] + # depth + sampled_depth = F.grid_sample( + input=depth_images, grid=grid, mode="nearest", padding_mode="zeros", align_corners=False + ) # [batch, N, 1] + sampled_depth = sampled_depth.squeeze(2) # [batch, 1, N] + # colors + sampled_colors = None + if color_images is not None: + sampled_colors = F.grid_sample( + input=color_images, grid=grid, mode="nearest", padding_mode="zeros", align_corners=False + ) # [batch, N, 3] + sampled_colors = sampled_colors.squeeze(2) # [batch, 3, N] + + dist = sampled_depth - voxel_depth # [batch, 1, N] + + # x = self.voxel_world_coords[:, :3].permute(0, 2, 1) + # eps = torch.finfo(x.dtype).eps + # x_mag_sq = torch.sum(x ** 2, dim=-1).clamp_min(eps) + # truncation_weight = torch.where(x_mag_sq <= 1, torch.ones_like(x_mag_sq), + # ((2 * torch.sqrt(x_mag_sq) - 1) / x_mag_sq)) + # truncation = truncation_weight.reciprocal() * self.truncation + + truncation = self.truncation + + tsdf_values = torch.clamp(dist / truncation, min=-1.0, max=1.0) # [batch, 1, N] + valid_points = (voxel_depth > 0) & (sampled_depth > 0) & (dist > -self.truncation) # [batch, 1, N] + + # Sequentially update the TSDF... + for i in range(batch_size): + valid_points_i = valid_points[i] + valid_points_i_shape = valid_points_i.view(*shape) # [xdim, ydim, zdim] + + # the old values + old_tsdf_values_i = self.values[valid_points_i_shape] + old_weights_i = self.weights[valid_points_i_shape] + + # the new values + # TODO: let the new weight be configurable + new_tsdf_values_i = tsdf_values[i][valid_points_i] + new_weights_i = 1.0 + + total_weights = old_weights_i + new_weights_i + + self.values[valid_points_i_shape] = (old_tsdf_values_i * old_weights_i + + new_tsdf_values_i * new_weights_i) / total_weights + # self.weights[valid_points_i_shape] = torch.clamp(total_weights, max=1.0) + self.weights[valid_points_i_shape] = total_weights + + if sampled_colors is not None: + old_colors_i = self.colors[valid_points_i_shape] # [M, 3] + new_colors_i = sampled_colors[i][:, valid_points_i.squeeze(0)].permute(1, 0) # [M, 3] + self.colors[valid_points_i_shape] = (old_colors_i * old_weights_i[:, None] + + new_colors_i * new_weights_i) / total_weights[:, None] + + +def main(unused_argv): + config = configs.load_config() + config.compute_visibility = True + + config.exp_path = os.path.join("exp", config.exp_name) + config.mesh_path = os.path.join("exp", config.exp_name, "mesh") + config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints') + os.makedirs(config.mesh_path, exist_ok=True) + + # accelerator for DDP + accelerator = accelerate.Accelerator() + device = accelerator.device + + # setup logger + logging.basicConfig( + format="%(asctime)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + handlers=[logging.StreamHandler(sys.stdout), + logging.FileHandler(os.path.join(config.exp_path, 'log_extract.txt'))], + level=logging.INFO, + ) + sys.excepthook = utils.handle_exception + logger = accelerate.logging.get_logger(__name__) + logger.info(config) + logger.info(accelerator.state, main_process_only=False) + + config.world_size = accelerator.num_processes + config.global_rank = accelerator.process_index + accelerate.utils.set_seed(config.seed, device_specific=True) + + # setup model and optimizer + model = models.Model(config=config) + model = accelerator.prepare(model) + step = 
+    model.eval()
+    module = accelerator.unwrap_model(model)
+
+    dataset = datasets.load_dataset('train', config.data_dir, config)
+    dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)),
+                                             shuffle=False,
+                                             batch_size=1,
+                                             collate_fn=dataset.collate_fn,
+                                             )
+    dataiter = iter(dataloader)
+    if config.rawnerf_mode:
+        postprocess_fn = dataset.metadata['postprocess_fn']
+    else:
+        postprocess_fn = lambda z: z
+
+    out_name = f'train_preds_step_{step}'
+    out_dir = os.path.join(config.mesh_path, out_name)
+    utils.makedirs(out_dir)
+    logger.info("Rendering the training set to {}".format(out_dir))
+
+    path_fn = lambda x: os.path.join(out_dir, x)
+
+    # Ensure sufficient zero-padding of image indices in output filenames.
+    zpad = max(3, len(str(dataset.size - 1)))
+    idx_to_str = lambda idx: str(idx).zfill(zpad)
+
+    for idx in range(dataset.size):
+        # Draw the batch first so the iterator stays aligned with idx even when
+        # already-rendered images are skipped below.
+        batch = next(dataiter)
+        # Skip images that have already been rendered.
+        idx_str = idx_to_str(idx)
+        curr_file = path_fn(f'color_{idx_str}.png')
+        if utils.file_exists(curr_file):
+            logger.info(f'Image {idx + 1}/{dataset.size} already exists, skipping')
+            continue
+        batch = tree_map(lambda x: x.to(accelerator.device) if x is not None else None, batch)
+        logger.info(f'Evaluating image {idx + 1}/{dataset.size}')
+        eval_start_time = time.time()
+        rendering = models.render_image(model, accelerator,
+                                        batch, False, 1, config)
+
+        logger.info(f'Rendered in {(time.time() - eval_start_time):0.3f}s')
+
+        if accelerator.is_main_process:  # Only record via host 0.
+            rendering['rgb'] = postprocess_fn(rendering['rgb'])
+            rendering = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, rendering)
+            utils.save_img_u8(rendering['rgb'], path_fn(f'color_{idx_str}.png'))
+            utils.save_img_f32(rendering['distance_mean'],
+                               path_fn(f'distance_mean_{idx_str}.tiff'))
+            utils.save_img_f32(rendering['distance_median'],
+                               path_fn(f'distance_median_{idx_str}.tiff'))
+
+    # if accelerator.is_main_process:
+    tsdf = TSDF(config, accelerator)
+
+    c2w = torch.from_numpy(dataset.camtoworlds[:, :3, :4]).float().to(device)
+
+    # make c2w homogeneous
+    c2w = torch.cat([c2w, torch.zeros(c2w.shape[0], 1, 4, device=device)], dim=1)
+    c2w[:, 3, 3] = 1
+    K = torch.from_numpy(dataset.pixtocams).float().to(device).inverse()
+
+    logger.info('Reading images')
+    rgb_files = sorted(glob.glob(path_fn('color_*.png')))
+    depth_files = sorted(glob.glob(path_fn('distance_median_*.tiff')))
+    assert len(rgb_files) == len(depth_files)
+    color_images = []
+    depth_images = []
+    for rgb_file, depth_file in zip(tqdm(rgb_files, disable=not accelerator.is_main_process), depth_files):
+        color_images.append(utils.load_img(rgb_file) / 255)
+        depth_images.append(utils.load_img(depth_file)[..., None])
+
+    color_images = torch.tensor(np.array(color_images), device=device).permute(0, 3, 1, 2)  # shape (N, 3, H, W)
+    depth_images = torch.tensor(np.array(depth_images), device=device).permute(0, 3, 1, 2)  # shape (N, 1, H, W)
+
+    batch_size = 1
+    logger.info("Integrating the TSDF")
+    for i in tqdm(range(0, len(c2w), batch_size), disable=not accelerator.is_main_process):
+        tsdf.integrate_tsdf(
+            c2w[i: i + batch_size],
+            K,
+            depth_images[i: i + batch_size],
+            color_images=color_images[i: i + batch_size],
+        )
+
+    logger.info("Saving the TSDF mesh")
+    tsdf.export_mesh(os.path.join(config.mesh_path, "tsdf_mesh.ply"))
+    accelerator.wait_for_everyone()
+    logger.info('Finished extracting the mesh using TSDF.')
+
+
+if __name__ == '__main__':
+    with gin.config_scope('bake'):
+        app.run(main)
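+
+# Example invocation (hypothetical; the script name and the exact gin flags depend on
+# how this repository wires up configs.load_config and its absl flags):
+#
+#   accelerate launch extract.py \
+#       --gin_configs=configs/360.gin \
+#       --gin_bindings="Config.data_dir = '${DATA_DIR}'" \
+#       --gin_bindings="Config.exp_name = '${EXP_NAME}'"
+#
+# The extracted mesh is written to exp/${EXP_NAME}/mesh/tsdf_mesh.ply.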