Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	add code
Browse filesThis view is limited to 50 files because it contains too many changes.  
							See raw diff
- .gitignore +23 -0
- README.md +1 -1
- app.py +277 -0
- arguments/__init__.py +113 -0
- assets/example/TT-family-3-views/000301.jpg +0 -0
- assets/example/TT-family-3-views/000388.jpg +0 -0
- assets/example/TT-family-3-views/000491.jpg +0 -0
- assets/example/dl3dv-ba55-3-views/frame_60.jpg +0 -0
- assets/example/dl3dv-ba55-3-views/frame_61.jpg +0 -0
- assets/example/dl3dv-ba55-3-views/frame_62.jpg +0 -0
- assets/example/sora-santorini-3-views/frame_00.jpg +0 -0
- assets/example/sora-santorini-3-views/frame_06.jpg +0 -0
- assets/example/sora-santorini-3-views/frame_12.jpg +0 -0
- assets/load/.gitkeep +0 -0
- coarse_init_infer.py +100 -0
- gaussian_renderer/__init__.py +144 -0
- gaussian_renderer/__init__3dgs.py +100 -0
- gaussian_renderer/network_gui.py +86 -0
- lpipsPyTorch/__init__.py +21 -0
- lpipsPyTorch/modules/lpips.py +36 -0
- lpipsPyTorch/modules/networks.py +96 -0
- lpipsPyTorch/modules/utils.py +30 -0
- render_by_interp.py +152 -0
- requirements.txt +17 -0
- scene/__init__.py +96 -0
- scene/cameras.py +71 -0
- scene/colmap_loader.py +294 -0
- scene/dataset_readers.py +363 -0
- scene/gaussian_model.py +502 -0
- submodules/diff-gaussian-rasterization/.gitignore +3 -0
- submodules/diff-gaussian-rasterization/.gitmodules +3 -0
- submodules/diff-gaussian-rasterization/CMakeLists.txt +36 -0
- submodules/diff-gaussian-rasterization/LICENSE.md +83 -0
- submodules/diff-gaussian-rasterization/README.md +19 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h +175 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.cu +657 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.h +65 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h +19 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.cu +455 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.h +66 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h +88 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu +434 -0
- submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h +74 -0
- submodules/diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py +221 -0
- submodules/diff-gaussian-rasterization/ext.cpp +19 -0
- submodules/diff-gaussian-rasterization/rasterize_points.cu +217 -0
- submodules/diff-gaussian-rasterization/rasterize_points.h +67 -0
- submodules/diff-gaussian-rasterization/setup.py +34 -0
- submodules/diff-gaussian-rasterization/third_party/glm/.appveyor.yml +92 -0
- submodules/diff-gaussian-rasterization/third_party/glm/.gitignore +61 -0
    	
        .gitignore
    ADDED
    
    | @@ -0,0 +1,23 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /.idea/
         | 
| 2 | 
            +
            /work_dirs*
         | 
| 3 | 
            +
            .vscode/
         | 
| 4 | 
            +
            /tmp
         | 
| 5 | 
            +
            /data
         | 
| 6 | 
            +
            /checkpoints
         | 
| 7 | 
            +
            *.so
         | 
| 8 | 
            +
            *.patch
         | 
| 9 | 
            +
            __pycache__/
         | 
| 10 | 
            +
            *.egg-info/
         | 
| 11 | 
            +
            /viz*
         | 
| 12 | 
            +
            /submit*
         | 
| 13 | 
            +
            build/
         | 
| 14 | 
            +
            *.pyd
         | 
| 15 | 
            +
            /cache*
         | 
| 16 | 
            +
            *.stl
         | 
| 17 | 
            +
            *.pth
         | 
| 18 | 
            +
            /venv/
         | 
| 19 | 
            +
            .nk8s
         | 
| 20 | 
            +
            *.mp4
         | 
| 21 | 
            +
            .vs
         | 
| 22 | 
            +
            /exp/
         | 
| 23 | 
            +
            /dev/
         | 
    	
        README.md
    CHANGED
    
    | @@ -6,7 +6,7 @@ colorTo: green | |
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
             
            sdk_version: 4.20.1
         | 
| 8 | 
             
            python_version: 3.10.13
         | 
| 9 | 
            -
            app_file:  | 
| 10 | 
             
            pinned: false
         | 
| 11 | 
             
            license: mit
         | 
| 12 | 
             
            short_description: Sparse-view SFM-free Gaussian Splatting in Seconds
         | 
|  | |
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
             
            sdk_version: 4.20.1
         | 
| 8 | 
             
            python_version: 3.10.13
         | 
| 9 | 
            +
            app_file: app.py
         | 
| 10 | 
             
            pinned: false
         | 
| 11 | 
             
            license: mit
         | 
| 12 | 
             
            short_description: Sparse-view SFM-free Gaussian Splatting in Seconds
         | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,277 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os, subprocess, shlex, sys, gc
         | 
| 2 | 
            +
            import time
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import shutil
         | 
| 6 | 
            +
            import argparse
         | 
| 7 | 
            +
            import gradio as gr
         | 
| 8 | 
            +
            import uuid
         | 
| 9 | 
            +
            import spaces
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            subprocess.run(shlex.split("pip install wheel/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))
         | 
| 12 | 
            +
            subprocess.run(shlex.split("pip install wheel/simple_knn-0.0.0-cp310-cp310-linux_x86_64.whl"))
         | 
| 13 | 
            +
            subprocess.run(shlex.split("pip install wheel/curope-0.0.0-cp310-cp310-linux_x86_64.whl"))
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            BASE_DIR = os.path.dirname(os.path.abspath(__file__))
         | 
| 16 | 
            +
            os.sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
         | 
| 17 | 
            +
            # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
         | 
| 18 | 
            +
            from dust3r.inference import inference
         | 
| 19 | 
            +
            from dust3r.model import AsymmetricCroCo3DStereo
         | 
| 20 | 
            +
            from dust3r.utils.device import to_numpy
         | 
| 21 | 
            +
            from dust3r.image_pairs import make_pairs
         | 
| 22 | 
            +
            from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
         | 
| 23 | 
            +
            from utils.dust3r_utils import compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            from argparse import ArgumentParser, Namespace
         | 
| 26 | 
            +
            from arguments import ModelParams, PipelineParams, OptimizationParams
         | 
| 27 | 
            +
            from train_joint import training
         | 
| 28 | 
            +
            from render_by_interp import render_sets
         | 
| 29 | 
            +
            GRADIO_CACHE_FOLDER = './gradio_cache_folder'
         | 
| 30 | 
            +
            #############################################################################################################################################
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            def get_dust3r_args_parser():
         | 
| 34 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 35 | 
            +
                parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
         | 
| 36 | 
            +
                parser.add_argument("--model_path", type=str, default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
         | 
| 37 | 
            +
                parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
         | 
| 38 | 
            +
                parser.add_argument("--batch_size", type=int, default=1)
         | 
| 39 | 
            +
                parser.add_argument("--schedule", type=str, default='linear')
         | 
| 40 | 
            +
                parser.add_argument("--lr", type=float, default=0.01)
         | 
| 41 | 
            +
                parser.add_argument("--niter", type=int, default=300)
         | 
| 42 | 
            +
                parser.add_argument("--focal_avg", type=bool, default=True)
         | 
| 43 | 
            +
                parser.add_argument("--n_views", type=int, default=3)
         | 
| 44 | 
            +
                parser.add_argument("--base_path", type=str, default=GRADIO_CACHE_FOLDER) 
         | 
| 45 | 
            +
                return parser
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            @spaces.GPU(duration=300)
         | 
| 49 | 
            +
            def process(inputfiles, input_path=None):
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                if input_path is not None:
         | 
| 52 | 
            +
                    imgs_path = './assets/example/' + input_path
         | 
| 53 | 
            +
                    imgs_names = sorted(os.listdir(imgs_path))
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    inputfiles = []
         | 
| 56 | 
            +
                    for imgs_name in imgs_names:
         | 
| 57 | 
            +
                        file_path = os.path.join(imgs_path, imgs_name)
         | 
| 58 | 
            +
                        print(file_path)
         | 
| 59 | 
            +
                        inputfiles.append(file_path)
         | 
| 60 | 
            +
                    print(inputfiles)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                # ------ (1) Coarse Geometric Initialization ------
         | 
| 63 | 
            +
                # os.system(f"rm -rf {GRADIO_CACHE_FOLDER}")
         | 
| 64 | 
            +
                parser = get_dust3r_args_parser()
         | 
| 65 | 
            +
                opt = parser.parse_args()
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                tmp_user_folder = str(uuid.uuid4()).replace("-", "")
         | 
| 68 | 
            +
                opt.img_base_path = os.path.join(opt.base_path, tmp_user_folder)
         | 
| 69 | 
            +
                img_folder_path = os.path.join(opt.img_base_path, "images")    
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                img_folder_path = os.path.join(opt.img_base_path, "images")    
         | 
| 72 | 
            +
                model = AsymmetricCroCo3DStereo.from_pretrained(opt.model_path).to(opt.device)
         | 
| 73 | 
            +
                os.makedirs(img_folder_path, exist_ok=True)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                opt.n_views = len(inputfiles)  
         | 
| 76 | 
            +
                if opt.n_views == 1:
         | 
| 77 | 
            +
                    raise gr.Error("The number of input images should be greater than 1.")
         | 
| 78 | 
            +
                print("Multiple images: ", inputfiles)
         | 
| 79 | 
            +
                for image_path in inputfiles:
         | 
| 80 | 
            +
                    if input_path is not None:
         | 
| 81 | 
            +
                        shutil.copy(image_path, img_folder_path)
         | 
| 82 | 
            +
                    else:
         | 
| 83 | 
            +
                        shutil.move(image_path, img_folder_path)
         | 
| 84 | 
            +
                train_img_list = sorted(os.listdir(img_folder_path))
         | 
| 85 | 
            +
                assert len(train_img_list)==opt.n_views, f"Number of images in the folder is not equal to {opt.n_views}"
         | 
| 86 | 
            +
                images, ori_size, imgs_resolution = load_images(img_folder_path, size=512) 
         | 
| 87 | 
            +
                resolutions_are_equal = len(set(imgs_resolution)) == 1
         | 
| 88 | 
            +
                if resolutions_are_equal == False:
         | 
| 89 | 
            +
                    raise gr.Error("The resolution of the input image should be the same.")
         | 
| 90 | 
            +
                print("ori_size", ori_size)
         | 
| 91 | 
            +
                start_time = time.time()
         | 
| 92 | 
            +
                ######################################################
         | 
| 93 | 
            +
                pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
         | 
| 94 | 
            +
                output = inference(pairs, model, opt.device, batch_size=opt.batch_size)
         | 
| 95 | 
            +
                output_colmap_path=img_folder_path.replace("images", "sparse/0")
         | 
| 96 | 
            +
                os.makedirs(output_colmap_path, exist_ok=True)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                scene = global_aligner(output, device=opt.device, mode=GlobalAlignerMode.PointCloudOptimizer)
         | 
| 99 | 
            +
                loss = compute_global_alignment(scene=scene, init="mst", niter=opt.niter, schedule=opt.schedule, lr=opt.lr, focal_avg=opt.focal_avg)
         | 
| 100 | 
            +
                scene = scene.clean_pointcloud()   
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                imgs = to_numpy(scene.imgs)
         | 
| 103 | 
            +
                focals = scene.get_focals()
         | 
| 104 | 
            +
                poses = to_numpy(scene.get_im_poses())
         | 
| 105 | 
            +
                pts3d = to_numpy(scene.get_pts3d())
         | 
| 106 | 
            +
                scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
         | 
| 107 | 
            +
                confidence_masks = to_numpy(scene.get_masks())
         | 
| 108 | 
            +
                intrinsics = to_numpy(scene.get_intrinsics())
         | 
| 109 | 
            +
                ######################################################
         | 
| 110 | 
            +
                end_time = time.time()
         | 
| 111 | 
            +
                print(f"Time taken for {opt.n_views} views: {end_time-start_time} seconds")
         | 
| 112 | 
            +
                save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
         | 
| 113 | 
            +
                save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
         | 
| 114 | 
            +
                pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
         | 
| 115 | 
            +
                color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
         | 
| 116 | 
            +
                color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
         | 
| 117 | 
            +
                storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
         | 
| 118 | 
            +
                pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
         | 
| 119 | 
            +
                np.save(output_colmap_path + "/pts_4_3dgs_all.npy", pts_4_3dgs_all)
         | 
| 120 | 
            +
                np.save(output_colmap_path + "/focal.npy", np.array(focals.cpu()))
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                ### save VRAM
         | 
| 123 | 
            +
                del scene
         | 
| 124 | 
            +
                torch.cuda.empty_cache()
         | 
| 125 | 
            +
                gc.collect()
         | 
| 126 | 
            +
                ##################################################################################################################################################
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                # ------ (2) Fast 3D-Gaussian Optimization ------
         | 
| 129 | 
            +
                parser = ArgumentParser(description="Training script parameters")
         | 
| 130 | 
            +
                lp = ModelParams(parser)
         | 
| 131 | 
            +
                op = OptimizationParams(parser)
         | 
| 132 | 
            +
                pp = PipelineParams(parser)
         | 
| 133 | 
            +
                parser.add_argument('--debug_from', type=int, default=-1)
         | 
| 134 | 
            +
                parser.add_argument("--test_iterations", nargs="+", type=int, default=[])
         | 
| 135 | 
            +
                parser.add_argument("--save_iterations", nargs="+", type=int, default=[])
         | 
| 136 | 
            +
                parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
         | 
| 137 | 
            +
                parser.add_argument("--start_checkpoint", type=str, default = None)
         | 
| 138 | 
            +
                parser.add_argument("--scene", type=str, default="demo")
         | 
| 139 | 
            +
                parser.add_argument("--n_views", type=int, default=3)
         | 
| 140 | 
            +
                parser.add_argument("--get_video", action="store_true")
         | 
| 141 | 
            +
                parser.add_argument("--optim_pose", type=bool, default=True)
         | 
| 142 | 
            +
                parser.add_argument("--skip_train", action="store_true")
         | 
| 143 | 
            +
                parser.add_argument("--skip_test", action="store_true")
         | 
| 144 | 
            +
                args = parser.parse_args(sys.argv[1:])
         | 
| 145 | 
            +
                args.save_iterations.append(args.iterations)
         | 
| 146 | 
            +
                args.model_path = opt.img_base_path + '/output/'    
         | 
| 147 | 
            +
                args.source_path = opt.img_base_path
         | 
| 148 | 
            +
                # args.model_path = GRADIO_CACHE_FOLDER + '/output/'    
         | 
| 149 | 
            +
                # args.source_path = GRADIO_CACHE_FOLDER
         | 
| 150 | 
            +
                args.iteration = 1000
         | 
| 151 | 
            +
                os.makedirs(args.model_path, exist_ok=True)
         | 
| 152 | 
            +
                training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args)
         | 
| 153 | 
            +
                ##################################################################################################################################################
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                # ------ (3) Render video by interpolation ------
         | 
| 156 | 
            +
                parser = ArgumentParser(description="Testing script parameters")
         | 
| 157 | 
            +
                model = ModelParams(parser, sentinel=True)
         | 
| 158 | 
            +
                pipeline = PipelineParams(parser)
         | 
| 159 | 
            +
                args.eval = True
         | 
| 160 | 
            +
                args.get_video = True
         | 
| 161 | 
            +
                args.n_views = opt.n_views
         | 
| 162 | 
            +
                render_sets(
         | 
| 163 | 
            +
                    model.extract(args),
         | 
| 164 | 
            +
                    args.iteration,
         | 
| 165 | 
            +
                    pipeline.extract(args),
         | 
| 166 | 
            +
                    args.skip_train,
         | 
| 167 | 
            +
                    args.skip_test,
         | 
| 168 | 
            +
                    args,
         | 
| 169 | 
            +
                )
         | 
| 170 | 
            +
                output_ply_path = opt.img_base_path + f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
         | 
| 171 | 
            +
                output_video_path = opt.img_base_path + f'/output/demo_{opt.n_views}_view.mp4'
         | 
| 172 | 
            +
                # output_ply_path = GRADIO_CACHE_FOLDER+ f'/output/point_cloud/iteration_{args.iteration}/point_cloud.ply'
         | 
| 173 | 
            +
                # output_video_path = GRADIO_CACHE_FOLDER+ f'/output/demo_{opt.n_views}_view.mp4'
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                return  output_video_path, output_ply_path, output_ply_path
         | 
| 176 | 
            +
                ##################################################################################################################################################
         | 
| 177 | 
            +
             | 
| 178 | 
            +
             | 
| 179 | 
            +
             | 
| 180 | 
            +
            _TITLE = '''InstantSplat'''
         | 
| 181 | 
            +
            _DESCRIPTION = '''
         | 
| 182 | 
            +
            <div style="display: flex; justify-content: center; align-items: center;">
         | 
| 183 | 
            +
                <div style="width: 100%; text-align: center; font-size: 30px;">
         | 
| 184 | 
            +
                    <strong>InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds</strong>
         | 
| 185 | 
            +
                </div>
         | 
| 186 | 
            +
            </div> 
         | 
| 187 | 
            +
            <p></p>
         | 
| 188 | 
            +
             | 
| 189 | 
            +
            <div align="center">
         | 
| 190 | 
            +
                <a style="display:inline-block" href="https://instantsplat.github.io/"><img src='https://img.shields.io/badge/Project_Page-1c7d45?logo=gumtree'></a> 
         | 
| 191 | 
            +
                <a style="display:inline-block" href="https://www.youtube.com/watch?v=fxf_ypd7eD8"><img src='https://img.shields.io/badge/Demo_Video-E33122?logo=Youtube'></a> 
         | 
| 192 | 
            +
                <a style="display:inline-block" href="https://arxiv.org/abs/2403.20309"><img src="https://img.shields.io/badge/ArXiv-2403.20309-b31b1b?logo=arxiv" alt='arxiv'></a>
         | 
| 193 | 
            +
            </div>
         | 
| 194 | 
            +
            <p></p>
         | 
| 195 | 
            +
             | 
| 196 | 
            +
            * Official demo of: [InstantSplat: Sparse-view SfM-free Gaussian Splatting in Seconds](https://instantsplat.github.io/).
         | 
| 197 | 
            +
            * Sparse-view examples for direct viewing: you can simply click the examples (in the bottom of the page), to quickly view the results on representative data.
         | 
| 198 | 
            +
            * Training speeds may slow if the resolution or number of images is large. To achieve performance comparable to what has been reported, please conduct tests on your own GPU (A100/4090).
         | 
| 199 | 
            +
            '''
         | 
| 200 | 
            +
             | 
| 201 | 
            +
             | 
| 202 | 
            +
                # <a style="display:inline-block" href="https://github.com/VITA-Group/LightGaussian"><img src="https://img.shields.io/badge/Source_Code-black?logo=Github" alt='Github Source Code'></a> 
         | 
| 203 | 
            +
            #  
         | 
| 204 | 
            +
            #     <a style="display:inline-block" href="https://www.nvidia.com/en-us/"><img src="https://img.shields.io/badge/Nvidia-575757?logo=nvidia" alt='Nvidia'></a>
         | 
| 205 | 
            +
            # * If InstantSplat is helpful, please give us a star ⭐ on Github. Thanks! <a style="display:inline-block; margin-left: .5em" href="https://github.com/VITA-Group/LightGaussian"><img src='https://img.shields.io/github/stars/VITA-Group/LightGaussian?style=social'/></a>
         | 
| 206 | 
            +
             | 
| 207 | 
            +
             | 
| 208 | 
            +
            # block = gr.Blocks(title=_TITLE).queue()
         | 
| 209 | 
            +
            block = gr.Blocks().queue()
         | 
| 210 | 
            +
            with block:
         | 
| 211 | 
            +
                with gr.Row():
         | 
| 212 | 
            +
                    with gr.Column(scale=1):
         | 
| 213 | 
            +
                        # gr.Markdown('# ' + _TITLE)
         | 
| 214 | 
            +
                        gr.Markdown(_DESCRIPTION)
         | 
| 215 | 
            +
                
         | 
| 216 | 
            +
                with gr.Row(variant='panel'):
         | 
| 217 | 
            +
                    with gr.Tab("Input"):
         | 
| 218 | 
            +
                        inputfiles = gr.File(file_count="multiple", label="images")
         | 
| 219 | 
            +
                        input_path = gr.Textbox(visible=False, label="example_path")
         | 
| 220 | 
            +
                        button_gen = gr.Button("RUN")
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                with gr.Row(variant='panel'):        
         | 
| 223 | 
            +
                    with gr.Tab("Output"):
         | 
| 224 | 
            +
                        with gr.Column(scale=2):
         | 
| 225 | 
            +
                            output_model = gr.Model3D(                    
         | 
| 226 | 
            +
                                            label="3D Model (Gaussian)",
         | 
| 227 | 
            +
                                            # height=300,
         | 
| 228 | 
            +
                                            interactive=False,
         | 
| 229 | 
            +
                                            # clear_color=[1.0, 1.0, 1.0, 1.0]
         | 
| 230 | 
            +
                                        )
         | 
| 231 | 
            +
                            output_file = gr.File(label="ply")
         | 
| 232 | 
            +
                        with gr.Column(scale=1):
         | 
| 233 | 
            +
                            output_video = gr.Video(label="video")
         | 
| 234 | 
            +
                            
         | 
| 235 | 
            +
                button_gen.click(process, inputs=[inputfiles], outputs=[ output_video, output_file, output_model])
         | 
| 236 | 
            +
                
         | 
| 237 | 
            +
                # gr.Examples(
         | 
| 238 | 
            +
                #     examples=[
         | 
| 239 | 
            +
                #         "sora-santorini-3-views",
         | 
| 240 | 
            +
                #         # "TT-family-3-views",
         | 
| 241 | 
            +
                #         # "dl3dv-ba55-3-views",
         | 
| 242 | 
            +
                #     ],
         | 
| 243 | 
            +
                #     inputs=[input_path],
         | 
| 244 | 
            +
                #     outputs=[output_video, output_file, output_model],
         | 
| 245 | 
            +
                #     fn=lambda x: process(inputfiles=None, input_path=x),
         | 
| 246 | 
            +
                #     cache_examples=True,
         | 
| 247 | 
            +
                #     label='Sparse-view Examples'
         | 
| 248 | 
            +
                # )
         | 
| 249 | 
            +
            block.launch(server_name="0.0.0.0", share=False)
         | 
| 250 | 
            +
             | 
| 251 | 
            +
             | 
| 252 | 
            +
            # block = gr.Blocks(title=_TITLE).queue()
         | 
| 253 | 
            +
            # with block:
         | 
| 254 | 
            +
            #     with gr.Row():
         | 
| 255 | 
            +
            #         with gr.Column(scale=1):
         | 
| 256 | 
            +
            #             gr.Markdown('# ' + _TITLE)
         | 
| 257 | 
            +
            #     # gr.Markdown(_DESCRIPTION)
         | 
| 258 | 
            +
                
         | 
| 259 | 
            +
            #     with gr.Row(variant='panel'):
         | 
| 260 | 
            +
            #         with gr.Column(scale=1):
         | 
| 261 | 
            +
            #             with gr.Tab("Input"):
         | 
| 262 | 
            +
            #                 inputfiles = gr.File(file_count="multiple", label="images")
         | 
| 263 | 
            +
            #                 button_gen = gr.Button("RUN")
         | 
| 264 | 
            +
             | 
| 265 | 
            +
            #         with gr.Column(scale=2):
         | 
| 266 | 
            +
            #             with gr.Tab("Output"):
         | 
| 267 | 
            +
            #                 output_video = gr.Video(label="video")
         | 
| 268 | 
            +
            #                 output_model = gr.Model3D(                    
         | 
| 269 | 
            +
            #                                 label="3D Model (Gaussian)",
         | 
| 270 | 
            +
            #                                 height=300,
         | 
| 271 | 
            +
            #                                 interactive=False,
         | 
| 272 | 
            +
            #                             )
         | 
| 273 | 
            +
            #                 output_file = gr.File(label="ply")
         | 
| 274 | 
            +
             | 
| 275 | 
            +
            #         button_gen.click(process, inputs=[inputfiles], outputs=[ output_video, output_file, output_model])
         | 
| 276 | 
            +
             | 
| 277 | 
            +
            # block.launch(server_name="0.0.0.0", share=False)
         | 
    	
        arguments/__init__.py
    ADDED
    
    | @@ -0,0 +1,113 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from argparse import ArgumentParser, Namespace
         | 
| 13 | 
            +
            import sys
         | 
| 14 | 
            +
            import os
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            class GroupParams:
         | 
| 17 | 
            +
                pass
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            class ParamGroup:
         | 
| 20 | 
            +
                def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
         | 
| 21 | 
            +
                    group = parser.add_argument_group(name)
         | 
| 22 | 
            +
                    for key, value in vars(self).items():
         | 
| 23 | 
            +
                        shorthand = False
         | 
| 24 | 
            +
                        if key.startswith("_"):
         | 
| 25 | 
            +
                            shorthand = True
         | 
| 26 | 
            +
                            key = key[1:]
         | 
| 27 | 
            +
                        t = type(value)
         | 
| 28 | 
            +
                        value = value if not fill_none else None 
         | 
| 29 | 
            +
                        if shorthand:
         | 
| 30 | 
            +
                            if t == bool:
         | 
| 31 | 
            +
                                group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
         | 
| 32 | 
            +
                            else:
         | 
| 33 | 
            +
                                group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
         | 
| 34 | 
            +
                        else:
         | 
| 35 | 
            +
                            if t == bool:
         | 
| 36 | 
            +
                                group.add_argument("--" + key, default=value, action="store_true")
         | 
| 37 | 
            +
                            else:
         | 
| 38 | 
            +
                                group.add_argument("--" + key, default=value, type=t)
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                def extract(self, args):
         | 
| 41 | 
            +
                    group = GroupParams()
         | 
| 42 | 
            +
                    for arg in vars(args).items():
         | 
| 43 | 
            +
                        if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
         | 
| 44 | 
            +
                            setattr(group, arg[0], arg[1])
         | 
| 45 | 
            +
                    return group
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            class ModelParams(ParamGroup): 
         | 
| 48 | 
            +
                def __init__(self, parser, sentinel=False):
         | 
| 49 | 
            +
                    self.sh_degree = 3
         | 
| 50 | 
            +
                    self._source_path = ""
         | 
| 51 | 
            +
                    self._model_path = ""
         | 
| 52 | 
            +
                    self._images = "images"
         | 
| 53 | 
            +
                    self._resolution = -1
         | 
| 54 | 
            +
                    self._white_background = False
         | 
| 55 | 
            +
                    self.data_device = "cuda"
         | 
| 56 | 
            +
                    self.eval = False
         | 
| 57 | 
            +
                    super().__init__(parser, "Loading Parameters", sentinel)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def extract(self, args):
         | 
| 60 | 
            +
                    g = super().extract(args)
         | 
| 61 | 
            +
                    g.source_path = os.path.abspath(g.source_path)
         | 
| 62 | 
            +
                    return g
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            class PipelineParams(ParamGroup):
         | 
| 65 | 
            +
                def __init__(self, parser):
         | 
| 66 | 
            +
                    self.convert_SHs_python = False
         | 
| 67 | 
            +
                    self.compute_cov3D_python = False
         | 
| 68 | 
            +
                    self.debug = False
         | 
| 69 | 
            +
                    super().__init__(parser, "Pipeline Parameters")
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            class OptimizationParams(ParamGroup):
         | 
| 72 | 
            +
                def __init__(self, parser):
         | 
| 73 | 
            +
                    self.iterations = 1000
         | 
| 74 | 
            +
                    # self.iterations = 30_000
         | 
| 75 | 
            +
                    self.position_lr_init =  0.00016
         | 
| 76 | 
            +
                    self.position_lr_final = 0.0000016
         | 
| 77 | 
            +
                    self.position_lr_delay_mult = 0.01
         | 
| 78 | 
            +
                    self.position_lr_max_steps = 30_000
         | 
| 79 | 
            +
                    self.feature_lr = 0.0025
         | 
| 80 | 
            +
                    self.opacity_lr = 0.05
         | 
| 81 | 
            +
                    self.scaling_lr = 0.005
         | 
| 82 | 
            +
                    self.rotation_lr = 0.001
         | 
| 83 | 
            +
                    self.percent_dense = 0.01
         | 
| 84 | 
            +
                    self.lambda_dssim = 0.2
         | 
| 85 | 
            +
                    self.densification_interval = 100
         | 
| 86 | 
            +
                    self.opacity_reset_interval = 3000
         | 
| 87 | 
            +
                    self.densify_from_iter = 500
         | 
| 88 | 
            +
                    self.densify_until_iter = 15_000
         | 
| 89 | 
            +
                    self.densify_grad_threshold = 0.0002
         | 
| 90 | 
            +
                    self.random_background = False
         | 
| 91 | 
            +
                    super().__init__(parser, "Optimization Parameters")
         | 
| 92 | 
            +
             | 
| 93 | 
            +
            def get_combined_args(parser : ArgumentParser):
         | 
| 94 | 
            +
                cmdlne_string = sys.argv[1:]
         | 
| 95 | 
            +
                cfgfile_string = "Namespace()"
         | 
| 96 | 
            +
                args_cmdline = parser.parse_args(cmdlne_string)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                try:
         | 
| 99 | 
            +
                    cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
         | 
| 100 | 
            +
                    print("Looking for config file in", cfgfilepath)
         | 
| 101 | 
            +
                    with open(cfgfilepath) as cfg_file:
         | 
| 102 | 
            +
                        print("Config file found: {}".format(cfgfilepath))
         | 
| 103 | 
            +
                        cfgfile_string = cfg_file.read()
         | 
| 104 | 
            +
                except TypeError:
         | 
| 105 | 
            +
                    print("Config file not found at")
         | 
| 106 | 
            +
                    pass
         | 
| 107 | 
            +
                args_cfgfile = eval(cfgfile_string)
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                merged_dict = vars(args_cfgfile).copy()
         | 
| 110 | 
            +
                for k,v in vars(args_cmdline).items():
         | 
| 111 | 
            +
                    if v != None:
         | 
| 112 | 
            +
                        merged_dict[k] = v
         | 
| 113 | 
            +
                return Namespace(**merged_dict)
         | 
    	
        assets/example/TT-family-3-views/000301.jpg
    ADDED
    
    |   | 
    	
        assets/example/TT-family-3-views/000388.jpg
    ADDED
    
    |   | 
    	
        assets/example/TT-family-3-views/000491.jpg
    ADDED
    
    |   | 
    	
        assets/example/dl3dv-ba55-3-views/frame_60.jpg
    ADDED
    
    |   | 
    	
        assets/example/dl3dv-ba55-3-views/frame_61.jpg
    ADDED
    
    |   | 
    	
        assets/example/dl3dv-ba55-3-views/frame_62.jpg
    ADDED
    
    |   | 
    	
        assets/example/sora-santorini-3-views/frame_00.jpg
    ADDED
    
    |   | 
    	
        assets/example/sora-santorini-3-views/frame_06.jpg
    ADDED
    
    |   | 
    	
        assets/example/sora-santorini-3-views/frame_12.jpg
    ADDED
    
    |   | 
    	
        assets/load/.gitkeep
    ADDED
    
    | 
            File without changes
         | 
    	
        coarse_init_infer.py
    ADDED
    
    | @@ -0,0 +1,100 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import shutil
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import argparse
         | 
| 6 | 
            +
            import time
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            BASE_DIR = os.path.dirname(os.path.abspath(__file__))
         | 
| 9 | 
            +
            os.sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "submodules", "dust3r")))
         | 
| 10 | 
            +
            os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from dust3r.inference import inference
         | 
| 13 | 
            +
            from dust3r.model import AsymmetricCroCo3DStereo
         | 
| 14 | 
            +
            from dust3r.utils.device import to_numpy
         | 
| 15 | 
            +
            from dust3r.image_pairs import make_pairs
         | 
| 16 | 
            +
            from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
         | 
| 17 | 
            +
            from utils.dust3r_utils import  compute_global_alignment, load_images, storePly, save_colmap_cameras, save_colmap_images
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            def get_args_parser():
         | 
| 20 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 21 | 
            +
                parser.add_argument("--image_size", type=int, default=512, choices=[512, 224], help="image size")
         | 
| 22 | 
            +
                # parser.add_argument("--model_path", type=str, default="./checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
         | 
| 23 | 
            +
                parser.add_argument("--model_path", type=str, default="submodules/dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", help="path to the model weights")
         | 
| 24 | 
            +
                parser.add_argument("--device", type=str, default='cuda', help="pytorch device")
         | 
| 25 | 
            +
                parser.add_argument("--batch_size", type=int, default=1)
         | 
| 26 | 
            +
                parser.add_argument("--schedule", type=str, default='linear')
         | 
| 27 | 
            +
                parser.add_argument("--lr", type=float, default=0.01)
         | 
| 28 | 
            +
                parser.add_argument("--niter", type=int, default=300)
         | 
| 29 | 
            +
                parser.add_argument("--focal_avg", action="store_true")
         | 
| 30 | 
            +
                # parser.add_argument("--focal_avg", type=bool, default=True)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                parser.add_argument("--llffhold", type=int, default=2)
         | 
| 33 | 
            +
                parser.add_argument("--n_views", type=int, default=12)
         | 
| 34 | 
            +
                parser.add_argument("--img_base_path", type=str, default="/home/workspace/datasets/instantsplat/Tanks/Barn/24_views")
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                return parser
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            if __name__ == '__main__':
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                parser = get_args_parser()
         | 
| 41 | 
            +
                args = parser.parse_args()
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                model_path = args.model_path
         | 
| 44 | 
            +
                device = args.device
         | 
| 45 | 
            +
                batch_size = args.batch_size
         | 
| 46 | 
            +
                schedule = args.schedule
         | 
| 47 | 
            +
                lr = args.lr
         | 
| 48 | 
            +
                niter = args.niter
         | 
| 49 | 
            +
                n_views = args.n_views
         | 
| 50 | 
            +
                img_base_path = args.img_base_path
         | 
| 51 | 
            +
                img_folder_path = os.path.join(img_base_path, "images")
         | 
| 52 | 
            +
                os.makedirs(img_folder_path, exist_ok=True)
         | 
| 53 | 
            +
                model = AsymmetricCroCo3DStereo.from_pretrained(model_path).to(device)
         | 
| 54 | 
            +
                ##########################################################################################################################################################################################
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                train_img_list = sorted(os.listdir(img_folder_path))
         | 
| 57 | 
            +
                assert len(train_img_list)==n_views, f"Number of images ({len(train_img_list)}) in the folder ({img_folder_path}) is not equal to {n_views}"
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                # if len(os.listdir(img_folder_path)) != len(train_img_list):
         | 
| 60 | 
            +
                #     for img_name in train_img_list:
         | 
| 61 | 
            +
                #         src_path = os.path.join(img_base_path, "images", img_name)
         | 
| 62 | 
            +
                #         tgt_path = os.path.join(img_folder_path, img_name)
         | 
| 63 | 
            +
                #         print(src_path, tgt_path)
         | 
| 64 | 
            +
                #         shutil.copy(src_path, tgt_path)
         | 
| 65 | 
            +
                images, ori_size = load_images(img_folder_path, size=512)
         | 
| 66 | 
            +
                print("ori_size", ori_size)
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                start_time = time.time()
         | 
| 69 | 
            +
                ##########################################################################################################################################################################################
         | 
| 70 | 
            +
                pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
         | 
| 71 | 
            +
                output = inference(pairs, model, args.device, batch_size=batch_size)
         | 
| 72 | 
            +
                output_colmap_path=img_folder_path.replace("images", "sparse/0")
         | 
| 73 | 
            +
                os.makedirs(output_colmap_path, exist_ok=True)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                scene = global_aligner(output, device=args.device, mode=GlobalAlignerMode.PointCloudOptimizer)
         | 
| 76 | 
            +
                loss = compute_global_alignment(scene=scene, init="mst", niter=niter, schedule=schedule, lr=lr, focal_avg=args.focal_avg)
         | 
| 77 | 
            +
                scene = scene.clean_pointcloud()
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                imgs = to_numpy(scene.imgs)
         | 
| 80 | 
            +
                focals = scene.get_focals()
         | 
| 81 | 
            +
                poses = to_numpy(scene.get_im_poses())
         | 
| 82 | 
            +
                pts3d = to_numpy(scene.get_pts3d())
         | 
| 83 | 
            +
                scene.min_conf_thr = float(scene.conf_trf(torch.tensor(1.0)))
         | 
| 84 | 
            +
                confidence_masks = to_numpy(scene.get_masks())
         | 
| 85 | 
            +
                intrinsics = to_numpy(scene.get_intrinsics())
         | 
| 86 | 
            +
                ##########################################################################################################################################################################################
         | 
| 87 | 
            +
                end_time = time.time()
         | 
| 88 | 
            +
                print(f"Time taken for {n_views} views: {end_time-start_time} seconds")
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                # save
         | 
| 91 | 
            +
                save_colmap_cameras(ori_size, intrinsics, os.path.join(output_colmap_path, 'cameras.txt'))
         | 
| 92 | 
            +
                save_colmap_images(poses, os.path.join(output_colmap_path, 'images.txt'), train_img_list)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                pts_4_3dgs = np.concatenate([p[m] for p, m in zip(pts3d, confidence_masks)])
         | 
| 95 | 
            +
                color_4_3dgs = np.concatenate([p[m] for p, m in zip(imgs, confidence_masks)])
         | 
| 96 | 
            +
                color_4_3dgs = (color_4_3dgs * 255.0).astype(np.uint8)
         | 
| 97 | 
            +
                storePly(os.path.join(output_colmap_path, "points3D.ply"), pts_4_3dgs, color_4_3dgs)
         | 
| 98 | 
            +
                pts_4_3dgs_all = np.array(pts3d).reshape(-1, 3)
         | 
| 99 | 
            +
                np.save(output_colmap_path + "/pts_4_3dgs_all.npy", pts_4_3dgs_all)
         | 
| 100 | 
            +
                np.save(output_colmap_path + "/focal.npy", np.array(focals.cpu()))
         | 
    	
        gaussian_renderer/__init__.py
    ADDED
    
    | @@ -0,0 +1,144 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            import math
         | 
| 14 | 
            +
            from diff_gaussian_rasterization import (
         | 
| 15 | 
            +
                GaussianRasterizationSettings,
         | 
| 16 | 
            +
                GaussianRasterizer,
         | 
| 17 | 
            +
            )
         | 
| 18 | 
            +
            from scene.gaussian_model import GaussianModel
         | 
| 19 | 
            +
            from utils.sh_utils import eval_sh
         | 
| 20 | 
            +
            from utils.pose_utils import get_camera_from_tensor, quadmultiply
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            def render(
         | 
| 24 | 
            +
                viewpoint_camera,
         | 
| 25 | 
            +
                pc: GaussianModel,
         | 
| 26 | 
            +
                pipe,
         | 
| 27 | 
            +
                bg_color: torch.Tensor,
         | 
| 28 | 
            +
                scaling_modifier=1.0,
         | 
| 29 | 
            +
                override_color=None,
         | 
| 30 | 
            +
                camera_pose=None,
         | 
| 31 | 
            +
            ):
         | 
| 32 | 
            +
                """
         | 
| 33 | 
            +
                Render the scene.
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                Background tensor (bg_color) must be on GPU!
         | 
| 36 | 
            +
                """
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
         | 
| 39 | 
            +
                screenspace_points = (
         | 
| 40 | 
            +
                    torch.zeros_like(
         | 
| 41 | 
            +
                        pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
         | 
| 42 | 
            +
                    )
         | 
| 43 | 
            +
                    + 0
         | 
| 44 | 
            +
                )
         | 
| 45 | 
            +
                try:
         | 
| 46 | 
            +
                    screenspace_points.retain_grad()
         | 
| 47 | 
            +
                except:
         | 
| 48 | 
            +
                    pass
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                # Set up rasterization configuration
         | 
| 51 | 
            +
                tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
         | 
| 52 | 
            +
                tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                # Set camera pose as identity. Then, we will transform the Gaussians around camera_pose
         | 
| 55 | 
            +
                w2c = torch.eye(4).cuda()
         | 
| 56 | 
            +
                projmatrix = (
         | 
| 57 | 
            +
                    w2c.unsqueeze(0).bmm(viewpoint_camera.projection_matrix.unsqueeze(0))
         | 
| 58 | 
            +
                ).squeeze(0)
         | 
| 59 | 
            +
                camera_pos = w2c.inverse()[3, :3]
         | 
| 60 | 
            +
                raster_settings = GaussianRasterizationSettings(
         | 
| 61 | 
            +
                    image_height=int(viewpoint_camera.image_height),
         | 
| 62 | 
            +
                    image_width=int(viewpoint_camera.image_width),
         | 
| 63 | 
            +
                    tanfovx=tanfovx,
         | 
| 64 | 
            +
                    tanfovy=tanfovy,
         | 
| 65 | 
            +
                    bg=bg_color,
         | 
| 66 | 
            +
                    scale_modifier=scaling_modifier,
         | 
| 67 | 
            +
                    # viewmatrix=viewpoint_camera.world_view_transform,
         | 
| 68 | 
            +
                    # projmatrix=viewpoint_camera.full_proj_transform,
         | 
| 69 | 
            +
                    viewmatrix=w2c,
         | 
| 70 | 
            +
                    projmatrix=projmatrix,
         | 
| 71 | 
            +
                    sh_degree=pc.active_sh_degree,
         | 
| 72 | 
            +
                    # campos=viewpoint_camera.camera_center,
         | 
| 73 | 
            +
                    campos=camera_pos,
         | 
| 74 | 
            +
                    prefiltered=False,
         | 
| 75 | 
            +
                    debug=pipe.debug,
         | 
| 76 | 
            +
                )
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                rasterizer = GaussianRasterizer(raster_settings=raster_settings)
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                # means3D = pc.get_xyz
         | 
| 81 | 
            +
                rel_w2c = get_camera_from_tensor(camera_pose)
         | 
| 82 | 
            +
                # Transform mean and rot of Gaussians to camera frame
         | 
| 83 | 
            +
                gaussians_xyz = pc._xyz.clone()
         | 
| 84 | 
            +
                gaussians_rot = pc._rotation.clone()
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                xyz_ones = torch.ones(gaussians_xyz.shape[0], 1).cuda().float()
         | 
| 87 | 
            +
                xyz_homo = torch.cat((gaussians_xyz, xyz_ones), dim=1)
         | 
| 88 | 
            +
                gaussians_xyz_trans = (rel_w2c @ xyz_homo.T).T[:, :3]
         | 
| 89 | 
            +
                gaussians_rot_trans = quadmultiply(camera_pose[:4], gaussians_rot)
         | 
| 90 | 
            +
                means3D = gaussians_xyz_trans
         | 
| 91 | 
            +
                means2D = screenspace_points
         | 
| 92 | 
            +
                opacity = pc.get_opacity
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
         | 
| 95 | 
            +
                # scaling / rotation by the rasterizer.
         | 
| 96 | 
            +
                scales = None
         | 
| 97 | 
            +
                rotations = None
         | 
| 98 | 
            +
                cov3D_precomp = None
         | 
| 99 | 
            +
                if pipe.compute_cov3D_python:
         | 
| 100 | 
            +
                    cov3D_precomp = pc.get_covariance(scaling_modifier)
         | 
| 101 | 
            +
                else:
         | 
| 102 | 
            +
                    scales = pc.get_scaling
         | 
| 103 | 
            +
                    rotations = gaussians_rot_trans  # pc.get_rotation
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
         | 
| 106 | 
            +
                # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
         | 
| 107 | 
            +
                shs = None
         | 
| 108 | 
            +
                colors_precomp = None
         | 
| 109 | 
            +
                if override_color is None:
         | 
| 110 | 
            +
                    if pipe.convert_SHs_python:
         | 
| 111 | 
            +
                        shs_view = pc.get_features.transpose(1, 2).view(
         | 
| 112 | 
            +
                            -1, 3, (pc.max_sh_degree + 1) ** 2
         | 
| 113 | 
            +
                        )
         | 
| 114 | 
            +
                        dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(
         | 
| 115 | 
            +
                            pc.get_features.shape[0], 1
         | 
| 116 | 
            +
                        )
         | 
| 117 | 
            +
                        dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
         | 
| 118 | 
            +
                        sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
         | 
| 119 | 
            +
                        colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
         | 
| 120 | 
            +
                    else:
         | 
| 121 | 
            +
                        shs = pc.get_features
         | 
| 122 | 
            +
                else:
         | 
| 123 | 
            +
                    colors_precomp = override_color
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                # Rasterize visible Gaussians to image, obtain their radii (on screen).
         | 
| 126 | 
            +
                rendered_image, radii = rasterizer(
         | 
| 127 | 
            +
                    means3D=means3D,
         | 
| 128 | 
            +
                    means2D=means2D,
         | 
| 129 | 
            +
                    shs=shs,
         | 
| 130 | 
            +
                    colors_precomp=colors_precomp,
         | 
| 131 | 
            +
                    opacities=opacity,
         | 
| 132 | 
            +
                    scales=scales,
         | 
| 133 | 
            +
                    rotations=rotations,
         | 
| 134 | 
            +
                    cov3D_precomp=cov3D_precomp,
         | 
| 135 | 
            +
                )
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
         | 
| 138 | 
            +
                # They will be excluded from value updates used in the splitting criteria.
         | 
| 139 | 
            +
                return {
         | 
| 140 | 
            +
                    "render": rendered_image,
         | 
| 141 | 
            +
                    "viewspace_points": screenspace_points,
         | 
| 142 | 
            +
                    "visibility_filter": radii > 0,
         | 
| 143 | 
            +
                    "radii": radii,
         | 
| 144 | 
            +
                }
         | 
    	
        gaussian_renderer/__init__3dgs.py
    ADDED
    
    | @@ -0,0 +1,100 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            import math
         | 
| 14 | 
            +
            from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
         | 
| 15 | 
            +
            from scene.gaussian_model import GaussianModel
         | 
| 16 | 
            +
            from utils.sh_utils import eval_sh
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
                Render the scene. 
         | 
| 21 | 
            +
                
         | 
| 22 | 
            +
                Background tensor (bg_color) must be on GPU!
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
             
         | 
| 25 | 
            +
                # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
         | 
| 26 | 
            +
                screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
         | 
| 27 | 
            +
                try:
         | 
| 28 | 
            +
                    screenspace_points.retain_grad()
         | 
| 29 | 
            +
                except:
         | 
| 30 | 
            +
                    pass
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                # Set up rasterization configuration
         | 
| 33 | 
            +
                tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
         | 
| 34 | 
            +
                tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                raster_settings = GaussianRasterizationSettings(
         | 
| 37 | 
            +
                    image_height=int(viewpoint_camera.image_height),
         | 
| 38 | 
            +
                    image_width=int(viewpoint_camera.image_width),
         | 
| 39 | 
            +
                    tanfovx=tanfovx,
         | 
| 40 | 
            +
                    tanfovy=tanfovy,
         | 
| 41 | 
            +
                    bg=bg_color,
         | 
| 42 | 
            +
                    scale_modifier=scaling_modifier,
         | 
| 43 | 
            +
                    viewmatrix=viewpoint_camera.world_view_transform,
         | 
| 44 | 
            +
                    projmatrix=viewpoint_camera.full_proj_transform,
         | 
| 45 | 
            +
                    sh_degree=pc.active_sh_degree,
         | 
| 46 | 
            +
                    campos=viewpoint_camera.camera_center,
         | 
| 47 | 
            +
                    prefiltered=False,
         | 
| 48 | 
            +
                    debug=pipe.debug
         | 
| 49 | 
            +
                )
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                rasterizer = GaussianRasterizer(raster_settings=raster_settings)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                means3D = pc.get_xyz
         | 
| 54 | 
            +
                means2D = screenspace_points
         | 
| 55 | 
            +
                opacity = pc.get_opacity
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
         | 
| 58 | 
            +
                # scaling / rotation by the rasterizer.
         | 
| 59 | 
            +
                scales = None
         | 
| 60 | 
            +
                rotations = None
         | 
| 61 | 
            +
                cov3D_precomp = None
         | 
| 62 | 
            +
                if pipe.compute_cov3D_python:
         | 
| 63 | 
            +
                    cov3D_precomp = pc.get_covariance(scaling_modifier)
         | 
| 64 | 
            +
                else:
         | 
| 65 | 
            +
                    scales = pc.get_scaling
         | 
| 66 | 
            +
                    rotations = pc.get_rotation
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
         | 
| 69 | 
            +
                # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
         | 
| 70 | 
            +
                shs = None
         | 
| 71 | 
            +
                colors_precomp = None
         | 
| 72 | 
            +
                if override_color is None:
         | 
| 73 | 
            +
                    if pipe.convert_SHs_python:
         | 
| 74 | 
            +
                        shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
         | 
| 75 | 
            +
                        dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
         | 
| 76 | 
            +
                        dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
         | 
| 77 | 
            +
                        sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
         | 
| 78 | 
            +
                        colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
         | 
| 79 | 
            +
                    else:
         | 
| 80 | 
            +
                        shs = pc.get_features
         | 
| 81 | 
            +
                else:
         | 
| 82 | 
            +
                    colors_precomp = override_color
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                # Rasterize visible Gaussians to image, obtain their radii (on screen). 
         | 
| 85 | 
            +
                rendered_image, radii = rasterizer(
         | 
| 86 | 
            +
                    means3D = means3D,
         | 
| 87 | 
            +
                    means2D = means2D,
         | 
| 88 | 
            +
                    shs = shs,
         | 
| 89 | 
            +
                    colors_precomp = colors_precomp,
         | 
| 90 | 
            +
                    opacities = opacity,
         | 
| 91 | 
            +
                    scales = scales,
         | 
| 92 | 
            +
                    rotations = rotations,
         | 
| 93 | 
            +
                    cov3D_precomp = cov3D_precomp)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
         | 
| 96 | 
            +
                # They will be excluded from value updates used in the splitting criteria.
         | 
| 97 | 
            +
                return {"render": rendered_image,
         | 
| 98 | 
            +
                        "viewspace_points": screenspace_points,
         | 
| 99 | 
            +
                        "visibility_filter" : radii > 0,
         | 
| 100 | 
            +
                        "radii": radii}
         | 
    	
        gaussian_renderer/network_gui.py
    ADDED
    
    | @@ -0,0 +1,86 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            import traceback
         | 
| 14 | 
            +
            import socket
         | 
| 15 | 
            +
            import json
         | 
| 16 | 
            +
            from scene.cameras import MiniCam
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            host = "127.0.0.1"
         | 
| 19 | 
            +
            port = 6009
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            conn = None
         | 
| 22 | 
            +
            addr = None
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            def init(wish_host, wish_port):
         | 
| 27 | 
            +
                global host, port, listener
         | 
| 28 | 
            +
                host = wish_host
         | 
| 29 | 
            +
                port = wish_port
         | 
| 30 | 
            +
                listener.bind((host, port))
         | 
| 31 | 
            +
                listener.listen()
         | 
| 32 | 
            +
                listener.settimeout(0)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            def try_connect():
         | 
| 35 | 
            +
                global conn, addr, listener
         | 
| 36 | 
            +
                try:
         | 
| 37 | 
            +
                    conn, addr = listener.accept()
         | 
| 38 | 
            +
                    print(f"\nConnected by {addr}")
         | 
| 39 | 
            +
                    conn.settimeout(None)
         | 
| 40 | 
            +
                except Exception as inst:
         | 
| 41 | 
            +
                    pass
         | 
| 42 | 
            +
                        
         | 
| 43 | 
            +
            def read():
         | 
| 44 | 
            +
                global conn
         | 
| 45 | 
            +
                messageLength = conn.recv(4)
         | 
| 46 | 
            +
                messageLength = int.from_bytes(messageLength, 'little')
         | 
| 47 | 
            +
                message = conn.recv(messageLength)
         | 
| 48 | 
            +
                return json.loads(message.decode("utf-8"))
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            def send(message_bytes, verify):
         | 
| 51 | 
            +
                global conn
         | 
| 52 | 
            +
                if message_bytes != None:
         | 
| 53 | 
            +
                    conn.sendall(message_bytes)
         | 
| 54 | 
            +
                conn.sendall(len(verify).to_bytes(4, 'little'))
         | 
| 55 | 
            +
                conn.sendall(bytes(verify, 'ascii'))
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            def receive():
         | 
| 58 | 
            +
                message = read()
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                width = message["resolution_x"]
         | 
| 61 | 
            +
                height = message["resolution_y"]
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                if width != 0 and height != 0:
         | 
| 64 | 
            +
                    try:
         | 
| 65 | 
            +
                        do_training = bool(message["train"])
         | 
| 66 | 
            +
                        fovy = message["fov_y"]
         | 
| 67 | 
            +
                        fovx = message["fov_x"]
         | 
| 68 | 
            +
                        znear = message["z_near"]
         | 
| 69 | 
            +
                        zfar = message["z_far"]
         | 
| 70 | 
            +
                        do_shs_python = bool(message["shs_python"])
         | 
| 71 | 
            +
                        do_rot_scale_python = bool(message["rot_scale_python"])
         | 
| 72 | 
            +
                        keep_alive = bool(message["keep_alive"])
         | 
| 73 | 
            +
                        scaling_modifier = message["scaling_modifier"]
         | 
| 74 | 
            +
                        world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
         | 
| 75 | 
            +
                        world_view_transform[:,1] = -world_view_transform[:,1]
         | 
| 76 | 
            +
                        world_view_transform[:,2] = -world_view_transform[:,2]
         | 
| 77 | 
            +
                        full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
         | 
| 78 | 
            +
                        full_proj_transform[:,1] = -full_proj_transform[:,1]
         | 
| 79 | 
            +
                        custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
         | 
| 80 | 
            +
                    except Exception as e:
         | 
| 81 | 
            +
                        print("")
         | 
| 82 | 
            +
                        traceback.print_exc()
         | 
| 83 | 
            +
                        raise e
         | 
| 84 | 
            +
                    return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
         | 
| 85 | 
            +
                else:
         | 
| 86 | 
            +
                    return None, None, None, None, None, None
         | 
    	
        lpipsPyTorch/__init__.py
    ADDED
    
    | @@ -0,0 +1,21 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from .modules.lpips import LPIPS
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def lpips(x: torch.Tensor,
         | 
| 7 | 
            +
                      y: torch.Tensor,
         | 
| 8 | 
            +
                      net_type: str = 'alex',
         | 
| 9 | 
            +
                      version: str = '0.1'):
         | 
| 10 | 
            +
                r"""Function that measures
         | 
| 11 | 
            +
                Learned Perceptual Image Patch Similarity (LPIPS).
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                Arguments:
         | 
| 14 | 
            +
                    x, y (torch.Tensor): the input tensors to compare.
         | 
| 15 | 
            +
                    net_type (str): the network type to compare the features: 
         | 
| 16 | 
            +
                                    'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
         | 
| 17 | 
            +
                    version (str): the version of LPIPS. Default: 0.1.
         | 
| 18 | 
            +
                """
         | 
| 19 | 
            +
                device = x.device
         | 
| 20 | 
            +
                criterion = LPIPS(net_type, version).to(device)
         | 
| 21 | 
            +
                return criterion(x, y)
         | 
    	
        lpipsPyTorch/modules/lpips.py
    ADDED
    
    | @@ -0,0 +1,36 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn as nn
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from .networks import get_network, LinLayers
         | 
| 5 | 
            +
            from .utils import get_state_dict
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            class LPIPS(nn.Module):
         | 
| 9 | 
            +
                r"""Creates a criterion that measures
         | 
| 10 | 
            +
                Learned Perceptual Image Patch Similarity (LPIPS).
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                Arguments:
         | 
| 13 | 
            +
                    net_type (str): the network type to compare the features: 
         | 
| 14 | 
            +
                                    'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
         | 
| 15 | 
            +
                    version (str): the version of LPIPS. Default: 0.1.
         | 
| 16 | 
            +
                """
         | 
| 17 | 
            +
                def __init__(self, net_type: str = 'alex', version: str = '0.1'):
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                    assert version in ['0.1'], 'v0.1 is only supported now'
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                    super(LPIPS, self).__init__()
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    # pretrained network
         | 
| 24 | 
            +
                    self.net = get_network(net_type)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                    # linear layers
         | 
| 27 | 
            +
                    self.lin = LinLayers(self.net.n_channels_list)
         | 
| 28 | 
            +
                    self.lin.load_state_dict(get_state_dict(net_type, version))
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                def forward(self, x: torch.Tensor, y: torch.Tensor):
         | 
| 31 | 
            +
                    feat_x, feat_y = self.net(x), self.net(y)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
         | 
| 34 | 
            +
                    res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    return torch.sum(torch.cat(res, 0), 0, True)
         | 
    	
        lpipsPyTorch/modules/networks.py
    ADDED
    
    | @@ -0,0 +1,96 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Sequence
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from itertools import chain
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import torch.nn as nn
         | 
| 7 | 
            +
            from torchvision import models
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from .utils import normalize_activation
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            def get_network(net_type: str):
         | 
| 13 | 
            +
                if net_type == 'alex':
         | 
| 14 | 
            +
                    return AlexNet()
         | 
| 15 | 
            +
                elif net_type == 'squeeze':
         | 
| 16 | 
            +
                    return SqueezeNet()
         | 
| 17 | 
            +
                elif net_type == 'vgg':
         | 
| 18 | 
            +
                    return VGG16()
         | 
| 19 | 
            +
                else:
         | 
| 20 | 
            +
                    raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            class LinLayers(nn.ModuleList):
         | 
| 24 | 
            +
                def __init__(self, n_channels_list: Sequence[int]):
         | 
| 25 | 
            +
                    super(LinLayers, self).__init__([
         | 
| 26 | 
            +
                        nn.Sequential(
         | 
| 27 | 
            +
                            nn.Identity(),
         | 
| 28 | 
            +
                            nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
         | 
| 29 | 
            +
                        ) for nc in n_channels_list
         | 
| 30 | 
            +
                    ])
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    for param in self.parameters():
         | 
| 33 | 
            +
                        param.requires_grad = False
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            class BaseNet(nn.Module):
         | 
| 37 | 
            +
                def __init__(self):
         | 
| 38 | 
            +
                    super(BaseNet, self).__init__()
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    # register buffer
         | 
| 41 | 
            +
                    self.register_buffer(
         | 
| 42 | 
            +
                        'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
         | 
| 43 | 
            +
                    self.register_buffer(
         | 
| 44 | 
            +
                        'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                def set_requires_grad(self, state: bool):
         | 
| 47 | 
            +
                    for param in chain(self.parameters(), self.buffers()):
         | 
| 48 | 
            +
                        param.requires_grad = state
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                def z_score(self, x: torch.Tensor):
         | 
| 51 | 
            +
                    return (x - self.mean) / self.std
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def forward(self, x: torch.Tensor):
         | 
| 54 | 
            +
                    x = self.z_score(x)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                    output = []
         | 
| 57 | 
            +
                    for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
         | 
| 58 | 
            +
                        x = layer(x)
         | 
| 59 | 
            +
                        if i in self.target_layers:
         | 
| 60 | 
            +
                            output.append(normalize_activation(x))
         | 
| 61 | 
            +
                        if len(output) == len(self.target_layers):
         | 
| 62 | 
            +
                            break
         | 
| 63 | 
            +
                    return output
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            class SqueezeNet(BaseNet):
         | 
| 67 | 
            +
                def __init__(self):
         | 
| 68 | 
            +
                    super(SqueezeNet, self).__init__()
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    self.layers = models.squeezenet1_1(True).features
         | 
| 71 | 
            +
                    self.target_layers = [2, 5, 8, 10, 11, 12, 13]
         | 
| 72 | 
            +
                    self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    self.set_requires_grad(False)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
             | 
| 77 | 
            +
            class AlexNet(BaseNet):
         | 
| 78 | 
            +
                def __init__(self):
         | 
| 79 | 
            +
                    super(AlexNet, self).__init__()
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    self.layers = models.alexnet(True).features
         | 
| 82 | 
            +
                    self.target_layers = [2, 5, 8, 10, 12]
         | 
| 83 | 
            +
                    self.n_channels_list = [64, 192, 384, 256, 256]
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    self.set_requires_grad(False)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
             | 
| 88 | 
            +
            class VGG16(BaseNet):
         | 
| 89 | 
            +
                def __init__(self):
         | 
| 90 | 
            +
                    super(VGG16, self).__init__()
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                    self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
         | 
| 93 | 
            +
                    self.target_layers = [4, 9, 16, 23, 30]
         | 
| 94 | 
            +
                    self.n_channels_list = [64, 128, 256, 512, 512]
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    self.set_requires_grad(False)
         | 
    	
        lpipsPyTorch/modules/utils.py
    ADDED
    
    | @@ -0,0 +1,30 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from collections import OrderedDict
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def normalize_activation(x, eps=1e-10):
         | 
| 7 | 
            +
                norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
         | 
| 8 | 
            +
                return x / (norm_factor + eps)
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
         | 
| 12 | 
            +
                # build url
         | 
| 13 | 
            +
                url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
         | 
| 14 | 
            +
                    + f'master/lpips/weights/v{version}/{net_type}.pth'
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                # download
         | 
| 17 | 
            +
                old_state_dict = torch.hub.load_state_dict_from_url(
         | 
| 18 | 
            +
                    url, progress=True,
         | 
| 19 | 
            +
                    map_location=None if torch.cuda.is_available() else torch.device('cpu')
         | 
| 20 | 
            +
                )
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                # rename keys
         | 
| 23 | 
            +
                new_state_dict = OrderedDict()
         | 
| 24 | 
            +
                for key, val in old_state_dict.items():
         | 
| 25 | 
            +
                    new_key = key
         | 
| 26 | 
            +
                    new_key = new_key.replace('lin', '')
         | 
| 27 | 
            +
                    new_key = new_key.replace('model.', '')
         | 
| 28 | 
            +
                    new_state_dict[new_key] = val
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                return new_state_dict
         | 
    	
        render_by_interp.py
    ADDED
    
    | @@ -0,0 +1,152 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            from scene import Scene
         | 
| 14 | 
            +
            import os
         | 
| 15 | 
            +
            from tqdm import tqdm
         | 
| 16 | 
            +
            from os import makedirs
         | 
| 17 | 
            +
            from gaussian_renderer import render
         | 
| 18 | 
            +
            import torchvision
         | 
| 19 | 
            +
            from utils.general_utils import safe_state
         | 
| 20 | 
            +
            from argparse import ArgumentParser
         | 
| 21 | 
            +
            from arguments import ModelParams, PipelineParams, get_combined_args
         | 
| 22 | 
            +
            from gaussian_renderer import GaussianModel
         | 
| 23 | 
            +
            from utils.pose_utils import get_tensor_from_camera
         | 
| 24 | 
            +
            from utils.camera_utils import generate_interpolated_path
         | 
| 25 | 
            +
            from utils.camera_utils import visualizer
         | 
| 26 | 
            +
            import cv2
         | 
| 27 | 
            +
            import numpy as np
         | 
| 28 | 
            +
            import imageio
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def save_interpolate_pose(model_path, iter, n_views):
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                org_pose = np.load(model_path + f"pose/pose_{iter}.npy")
         | 
| 34 | 
            +
                # visualizer(org_pose, ["green" for _ in org_pose], model_path + "pose/poses_optimized.png")
         | 
| 35 | 
            +
                # n_interp = int(10 * 30 / n_views)  # 10second, fps=30
         | 
| 36 | 
            +
                n_interp = int(5 * 30 / n_views)  # 5second, fps=30
         | 
| 37 | 
            +
                all_inter_pose = []
         | 
| 38 | 
            +
                for i in range(n_views-1):
         | 
| 39 | 
            +
                    tmp_inter_pose = generate_interpolated_path(poses=org_pose[i:i+2], n_interp=n_interp)
         | 
| 40 | 
            +
                    all_inter_pose.append(tmp_inter_pose)
         | 
| 41 | 
            +
                all_inter_pose = np.array(all_inter_pose).reshape(-1, 3, 4)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                inter_pose_list = []
         | 
| 44 | 
            +
                for p in all_inter_pose:
         | 
| 45 | 
            +
                    tmp_view = np.eye(4)
         | 
| 46 | 
            +
                    tmp_view[:3, :3] = p[:3, :3]
         | 
| 47 | 
            +
                    tmp_view[:3, 3] = p[:3, 3]
         | 
| 48 | 
            +
                    inter_pose_list.append(tmp_view)
         | 
| 49 | 
            +
                inter_pose = np.stack(inter_pose_list, 0)
         | 
| 50 | 
            +
                # visualizer(inter_pose, ["blue" for _ in inter_pose], model_path + "pose/poses_interpolated.png")
         | 
| 51 | 
            +
                np.save(model_path + "pose/pose_interpolated.npy", inter_pose)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def images_to_video(image_folder, output_video_path, fps=30):
         | 
| 55 | 
            +
                """
         | 
| 56 | 
            +
                Convert images in a folder to a video.
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                Args:
         | 
| 59 | 
            +
                - image_folder (str): The path to the folder containing the images.
         | 
| 60 | 
            +
                - output_video_path (str): The path where the output video will be saved.
         | 
| 61 | 
            +
                - fps (int): Frames per second for the output video.
         | 
| 62 | 
            +
                """
         | 
| 63 | 
            +
                images = []
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                for filename in sorted(os.listdir(image_folder)):
         | 
| 66 | 
            +
                    if filename.endswith(('.png', '.jpg', '.jpeg', '.JPG', '.PNG')):
         | 
| 67 | 
            +
                        image_path = os.path.join(image_folder, filename)
         | 
| 68 | 
            +
                        image = imageio.imread(image_path)
         | 
| 69 | 
            +
                        images.append(image)
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                imageio.mimwrite(output_video_path, images, fps=fps)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
         | 
| 75 | 
            +
                render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
         | 
| 76 | 
            +
                makedirs(render_path, exist_ok=True)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                # for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
         | 
| 79 | 
            +
                for idx, view in enumerate(views):
         | 
| 80 | 
            +
                    camera_pose = get_tensor_from_camera(view.world_view_transform.transpose(0, 1))
         | 
| 81 | 
            +
                    rendering = render(
         | 
| 82 | 
            +
                        view, gaussians, pipeline, background, camera_pose=camera_pose
         | 
| 83 | 
            +
                    )["render"]
         | 
| 84 | 
            +
                    gt = view.original_image[0:3, :, :]
         | 
| 85 | 
            +
                    torchvision.utils.save_image(
         | 
| 86 | 
            +
                        rendering, os.path.join(render_path, "{0:05d}".format(idx) + ".png")
         | 
| 87 | 
            +
                    )
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            def render_sets(
         | 
| 91 | 
            +
                dataset: ModelParams,
         | 
| 92 | 
            +
                iteration: int,
         | 
| 93 | 
            +
                pipeline: PipelineParams,
         | 
| 94 | 
            +
                skip_train: bool,
         | 
| 95 | 
            +
                skip_test: bool,
         | 
| 96 | 
            +
                args,
         | 
| 97 | 
            +
            ):
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                # Applying interpolation
         | 
| 100 | 
            +
                save_interpolate_pose(dataset.model_path, iteration, args.n_views)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                with torch.no_grad():
         | 
| 103 | 
            +
                    gaussians = GaussianModel(dataset.sh_degree)
         | 
| 104 | 
            +
                    scene = Scene(dataset, gaussians, load_iteration=iteration, opt=args, shuffle=False)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
         | 
| 107 | 
            +
                    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                # render interpolated views
         | 
| 110 | 
            +
                render_set(
         | 
| 111 | 
            +
                    dataset.model_path,
         | 
| 112 | 
            +
                    "interp",
         | 
| 113 | 
            +
                    scene.loaded_iter,
         | 
| 114 | 
            +
                    scene.getTrainCameras(),
         | 
| 115 | 
            +
                    gaussians,
         | 
| 116 | 
            +
                    pipeline,
         | 
| 117 | 
            +
                    background,
         | 
| 118 | 
            +
                )
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                if args.get_video:
         | 
| 121 | 
            +
                    image_folder = os.path.join(dataset.model_path, f'interp/ours_{args.iteration}/renders')
         | 
| 122 | 
            +
                    output_video_file = os.path.join(dataset.model_path, f'{args.scene}_{args.n_views}_view.mp4')
         | 
| 123 | 
            +
                    images_to_video(image_folder, output_video_file, fps=30)
         | 
| 124 | 
            +
             | 
| 125 | 
            +
             | 
| 126 | 
            +
            if __name__ == "__main__":
         | 
| 127 | 
            +
                # Set up command line argument parser
         | 
| 128 | 
            +
                parser = ArgumentParser(description="Testing script parameters")
         | 
| 129 | 
            +
                model = ModelParams(parser, sentinel=True)
         | 
| 130 | 
            +
                pipeline = PipelineParams(parser)
         | 
| 131 | 
            +
                parser.add_argument("--iteration", default=-1, type=int)
         | 
| 132 | 
            +
                parser.add_argument("--skip_train", action="store_true")
         | 
| 133 | 
            +
                parser.add_argument("--skip_test", action="store_true")
         | 
| 134 | 
            +
                parser.add_argument("--quiet", action="store_true")
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                parser.add_argument("--get_video", action="store_true")
         | 
| 137 | 
            +
                parser.add_argument("--n_views", default=None, type=int)
         | 
| 138 | 
            +
                parser.add_argument("--scene", default=None, type=str)
         | 
| 139 | 
            +
                args = get_combined_args(parser)
         | 
| 140 | 
            +
                print("Rendering " + args.model_path)
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                # Initialize system state (RNG)
         | 
| 143 | 
            +
                # safe_state(args.quiet)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                render_sets(
         | 
| 146 | 
            +
                    model.extract(args),
         | 
| 147 | 
            +
                    args.iteration,
         | 
| 148 | 
            +
                    pipeline.extract(args),
         | 
| 149 | 
            +
                    args.skip_train,
         | 
| 150 | 
            +
                    args.skip_test,
         | 
| 151 | 
            +
                    args,
         | 
| 152 | 
            +
                )
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,17 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            torch==2.2.0
         | 
| 2 | 
            +
            torchvision
         | 
| 3 | 
            +
            roma
         | 
| 4 | 
            +
            evo
         | 
| 5 | 
            +
            gradio==5.0.1
         | 
| 6 | 
            +
            matplotlib
         | 
| 7 | 
            +
            tqdm
         | 
| 8 | 
            +
            opencv-python
         | 
| 9 | 
            +
            scipy
         | 
| 10 | 
            +
            einops
         | 
| 11 | 
            +
            trimesh
         | 
| 12 | 
            +
            tensorboard
         | 
| 13 | 
            +
            pyglet<2
         | 
| 14 | 
            +
            huggingface-hub[torch]>=0.22
         | 
| 15 | 
            +
            plyfile
         | 
| 16 | 
            +
            imageio[ffmpeg]
         | 
| 17 | 
            +
            spaces
         | 
    	
        scene/__init__.py
    ADDED
    
    | @@ -0,0 +1,96 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import os
         | 
| 13 | 
            +
            import random
         | 
| 14 | 
            +
            import json
         | 
| 15 | 
            +
            from utils.system_utils import searchForMaxIteration
         | 
| 16 | 
            +
            from scene.dataset_readers import sceneLoadTypeCallbacks
         | 
| 17 | 
            +
            from scene.gaussian_model import GaussianModel
         | 
| 18 | 
            +
            from arguments import ModelParams
         | 
| 19 | 
            +
            from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            class Scene:
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                gaussians : GaussianModel
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, opt=None, shuffle=True, resolution_scales=[1.0]):
         | 
| 26 | 
            +
                    """b
         | 
| 27 | 
            +
                    :param path: Path to colmap scene main folder.
         | 
| 28 | 
            +
                    """
         | 
| 29 | 
            +
                    self.model_path = args.model_path
         | 
| 30 | 
            +
                    self.loaded_iter = None
         | 
| 31 | 
            +
                    self.gaussians = gaussians
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    if load_iteration:
         | 
| 34 | 
            +
                        if load_iteration == -1:
         | 
| 35 | 
            +
                            self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
         | 
| 36 | 
            +
                        else:
         | 
| 37 | 
            +
                            self.loaded_iter = load_iteration
         | 
| 38 | 
            +
                        print("Loading trained model at iteration {}".format(self.loaded_iter))
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    self.train_cameras = {}
         | 
| 41 | 
            +
                    self.test_cameras = {}
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    if os.path.exists(os.path.join(args.source_path, "sparse")):
         | 
| 44 | 
            +
                        scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval, args, opt)
         | 
| 45 | 
            +
                    elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
         | 
| 46 | 
            +
                        print("Found transforms_train.json file, assuming Blender data set!")
         | 
| 47 | 
            +
                        scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
         | 
| 48 | 
            +
                    else:
         | 
| 49 | 
            +
                        assert False, "Could not recognize scene type!"
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    if not self.loaded_iter:
         | 
| 52 | 
            +
                        with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
         | 
| 53 | 
            +
                            dest_file.write(src_file.read())
         | 
| 54 | 
            +
                        json_cams = []
         | 
| 55 | 
            +
                        camlist = []
         | 
| 56 | 
            +
                        if scene_info.test_cameras:
         | 
| 57 | 
            +
                            camlist.extend(scene_info.test_cameras)
         | 
| 58 | 
            +
                        if scene_info.train_cameras:
         | 
| 59 | 
            +
                            camlist.extend(scene_info.train_cameras)
         | 
| 60 | 
            +
                        for id, cam in enumerate(camlist):
         | 
| 61 | 
            +
                            json_cams.append(camera_to_JSON(id, cam))
         | 
| 62 | 
            +
                        with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
         | 
| 63 | 
            +
                            json.dump(json_cams, file)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    if shuffle:
         | 
| 66 | 
            +
                        random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling
         | 
| 67 | 
            +
                        random.shuffle(scene_info.test_cameras)  # Multi-res consistent random shuffling
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    self.cameras_extent = scene_info.nerf_normalization["radius"]
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                    for resolution_scale in resolution_scales:
         | 
| 72 | 
            +
                        print("Loading Training Cameras")
         | 
| 73 | 
            +
                        self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
         | 
| 74 | 
            +
                        print('train_camera_num: ', len(self.train_cameras[resolution_scale]))
         | 
| 75 | 
            +
                        print("Loading Test Cameras")
         | 
| 76 | 
            +
                        self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
         | 
| 77 | 
            +
                        print('test_camera_num: ', len(self.test_cameras[resolution_scale]))
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    if self.loaded_iter:
         | 
| 80 | 
            +
                        self.gaussians.load_ply(os.path.join(self.model_path,
         | 
| 81 | 
            +
                                                                       "point_cloud",
         | 
| 82 | 
            +
                                                                       "iteration_" + str(self.loaded_iter),
         | 
| 83 | 
            +
                                                                       "point_cloud.ply"))
         | 
| 84 | 
            +
                    else:
         | 
| 85 | 
            +
                        self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
         | 
| 86 | 
            +
                        self.gaussians.init_RT_seq(self.train_cameras)
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def save(self, iteration):
         | 
| 89 | 
            +
                    point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
         | 
| 90 | 
            +
                    self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                def getTrainCameras(self, scale=1.0):
         | 
| 93 | 
            +
                    return self.train_cameras[scale]
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                def getTestCameras(self, scale=1.0):
         | 
| 96 | 
            +
                    return self.test_cameras[scale]
         | 
    	
        scene/cameras.py
    ADDED
    
    | @@ -0,0 +1,71 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            from torch import nn
         | 
| 14 | 
            +
            import numpy as np
         | 
| 15 | 
            +
            from utils.graphics_utils import getWorld2View2, getProjectionMatrix
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            class Camera(nn.Module):
         | 
| 18 | 
            +
                def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
         | 
| 19 | 
            +
                             image_name, uid,
         | 
| 20 | 
            +
                             trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
         | 
| 21 | 
            +
                             ):
         | 
| 22 | 
            +
                    super(Camera, self).__init__()
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                    self.uid = uid
         | 
| 25 | 
            +
                    self.colmap_id = colmap_id
         | 
| 26 | 
            +
                    self.R = R
         | 
| 27 | 
            +
                    self.T = T
         | 
| 28 | 
            +
                    self.FoVx = FoVx
         | 
| 29 | 
            +
                    self.FoVy = FoVy
         | 
| 30 | 
            +
                    self.image_name = image_name
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    try:
         | 
| 33 | 
            +
                        self.data_device = torch.device(data_device)
         | 
| 34 | 
            +
                    except Exception as e:
         | 
| 35 | 
            +
                        print(e)
         | 
| 36 | 
            +
                        print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
         | 
| 37 | 
            +
                        self.data_device = torch.device("cuda")
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                    self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
         | 
| 40 | 
            +
                    self.image_width = self.original_image.shape[2]
         | 
| 41 | 
            +
                    self.image_height = self.original_image.shape[1]
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    if gt_alpha_mask is not None:
         | 
| 44 | 
            +
                        self.original_image *= gt_alpha_mask.to(self.data_device)
         | 
| 45 | 
            +
                    else:
         | 
| 46 | 
            +
                        self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    self.zfar = 100.0
         | 
| 49 | 
            +
                    self.znear = 0.01
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    self.trans = trans
         | 
| 52 | 
            +
                    self.scale = scale
         | 
| 53 | 
            +
                    
         | 
| 54 | 
            +
                    self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
         | 
| 55 | 
            +
                    self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
         | 
| 56 | 
            +
                    self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
         | 
| 57 | 
            +
                    self.camera_center = self.world_view_transform.inverse()[3, :3]
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            class MiniCam:
         | 
| 60 | 
            +
                def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
         | 
| 61 | 
            +
                    self.image_width = width
         | 
| 62 | 
            +
                    self.image_height = height    
         | 
| 63 | 
            +
                    self.FoVy = fovy
         | 
| 64 | 
            +
                    self.FoVx = fovx
         | 
| 65 | 
            +
                    self.znear = znear
         | 
| 66 | 
            +
                    self.zfar = zfar
         | 
| 67 | 
            +
                    self.world_view_transform = world_view_transform
         | 
| 68 | 
            +
                    self.full_proj_transform = full_proj_transform
         | 
| 69 | 
            +
                    view_inv = torch.inverse(self.world_view_transform)
         | 
| 70 | 
            +
                    self.camera_center = view_inv[3][:3]
         | 
| 71 | 
            +
             | 
    	
        scene/colmap_loader.py
    ADDED
    
    | @@ -0,0 +1,294 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            import collections
         | 
| 14 | 
            +
            import struct
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            CameraModel = collections.namedtuple(
         | 
| 17 | 
            +
                "CameraModel", ["model_id", "model_name", "num_params"])
         | 
| 18 | 
            +
            Camera = collections.namedtuple(
         | 
| 19 | 
            +
                "Camera", ["id", "model", "width", "height", "params"])
         | 
| 20 | 
            +
            BaseImage = collections.namedtuple(
         | 
| 21 | 
            +
                "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
         | 
| 22 | 
            +
            Point3D = collections.namedtuple(
         | 
| 23 | 
            +
                "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
         | 
| 24 | 
            +
            CAMERA_MODELS = {
         | 
| 25 | 
            +
                CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
         | 
| 26 | 
            +
                CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
         | 
| 27 | 
            +
                CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
         | 
| 28 | 
            +
                CameraModel(model_id=3, model_name="RADIAL", num_params=5),
         | 
| 29 | 
            +
                CameraModel(model_id=4, model_name="OPENCV", num_params=8),
         | 
| 30 | 
            +
                CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
         | 
| 31 | 
            +
                CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
         | 
| 32 | 
            +
                CameraModel(model_id=7, model_name="FOV", num_params=5),
         | 
| 33 | 
            +
                CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
         | 
| 34 | 
            +
                CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
         | 
| 35 | 
            +
                CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
         | 
| 36 | 
            +
            }
         | 
| 37 | 
            +
            CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
         | 
| 38 | 
            +
                                     for camera_model in CAMERA_MODELS])
         | 
| 39 | 
            +
            CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
         | 
| 40 | 
            +
                                       for camera_model in CAMERA_MODELS])
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def qvec2rotmat(qvec):
         | 
| 44 | 
            +
                return np.array([
         | 
| 45 | 
            +
                    [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
         | 
| 46 | 
            +
                     2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
         | 
| 47 | 
            +
                     2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
         | 
| 48 | 
            +
                    [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
         | 
| 49 | 
            +
                     1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
         | 
| 50 | 
            +
                     2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
         | 
| 51 | 
            +
                    [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
         | 
| 52 | 
            +
                     2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
         | 
| 53 | 
            +
                     1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            def rotmat2qvec(R):
         | 
| 56 | 
            +
                Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
         | 
| 57 | 
            +
                K = np.array([
         | 
| 58 | 
            +
                    [Rxx - Ryy - Rzz, 0, 0, 0],
         | 
| 59 | 
            +
                    [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
         | 
| 60 | 
            +
                    [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
         | 
| 61 | 
            +
                    [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
         | 
| 62 | 
            +
                eigvals, eigvecs = np.linalg.eigh(K)
         | 
| 63 | 
            +
                qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
         | 
| 64 | 
            +
                if qvec[0] < 0:
         | 
| 65 | 
            +
                    qvec *= -1
         | 
| 66 | 
            +
                return qvec
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            class Image(BaseImage):
         | 
| 69 | 
            +
                def qvec2rotmat(self):
         | 
| 70 | 
            +
                    return qvec2rotmat(self.qvec)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
         | 
| 73 | 
            +
                """Read and unpack the next bytes from a binary file.
         | 
| 74 | 
            +
                :param fid:
         | 
| 75 | 
            +
                :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
         | 
| 76 | 
            +
                :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
         | 
| 77 | 
            +
                :param endian_character: Any of {@, =, <, >, !}
         | 
| 78 | 
            +
                :return: Tuple of read and unpacked values.
         | 
| 79 | 
            +
                """
         | 
| 80 | 
            +
                data = fid.read(num_bytes)
         | 
| 81 | 
            +
                return struct.unpack(endian_character + format_char_sequence, data)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            def read_points3D_text(path):
         | 
| 84 | 
            +
                """
         | 
| 85 | 
            +
                see: src/base/reconstruction.cc
         | 
| 86 | 
            +
                    void Reconstruction::ReadPoints3DText(const std::string& path)
         | 
| 87 | 
            +
                    void Reconstruction::WritePoints3DText(const std::string& path)
         | 
| 88 | 
            +
                """
         | 
| 89 | 
            +
                xyzs = None
         | 
| 90 | 
            +
                rgbs = None
         | 
| 91 | 
            +
                errors = None
         | 
| 92 | 
            +
                num_points = 0
         | 
| 93 | 
            +
                with open(path, "r") as fid:
         | 
| 94 | 
            +
                    while True:
         | 
| 95 | 
            +
                        line = fid.readline()
         | 
| 96 | 
            +
                        if not line:
         | 
| 97 | 
            +
                            break
         | 
| 98 | 
            +
                        line = line.strip()
         | 
| 99 | 
            +
                        if len(line) > 0 and line[0] != "#":
         | 
| 100 | 
            +
                            num_points += 1
         | 
| 101 | 
            +
             | 
| 102 | 
            +
             | 
| 103 | 
            +
                xyzs = np.empty((num_points, 3))
         | 
| 104 | 
            +
                rgbs = np.empty((num_points, 3))
         | 
| 105 | 
            +
                errors = np.empty((num_points, 1))
         | 
| 106 | 
            +
                count = 0
         | 
| 107 | 
            +
                with open(path, "r") as fid:
         | 
| 108 | 
            +
                    while True:
         | 
| 109 | 
            +
                        line = fid.readline()
         | 
| 110 | 
            +
                        if not line:
         | 
| 111 | 
            +
                            break
         | 
| 112 | 
            +
                        line = line.strip()
         | 
| 113 | 
            +
                        if len(line) > 0 and line[0] != "#":
         | 
| 114 | 
            +
                            elems = line.split()
         | 
| 115 | 
            +
                            xyz = np.array(tuple(map(float, elems[1:4])))
         | 
| 116 | 
            +
                            rgb = np.array(tuple(map(int, elems[4:7])))
         | 
| 117 | 
            +
                            error = np.array(float(elems[7]))
         | 
| 118 | 
            +
                            xyzs[count] = xyz
         | 
| 119 | 
            +
                            rgbs[count] = rgb
         | 
| 120 | 
            +
                            errors[count] = error
         | 
| 121 | 
            +
                            count += 1
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                return xyzs, rgbs, errors
         | 
| 124 | 
            +
             | 
| 125 | 
            +
            def read_points3D_binary(path_to_model_file):
         | 
| 126 | 
            +
                """
         | 
| 127 | 
            +
                see: src/base/reconstruction.cc
         | 
| 128 | 
            +
                    void Reconstruction::ReadPoints3DBinary(const std::string& path)
         | 
| 129 | 
            +
                    void Reconstruction::WritePoints3DBinary(const std::string& path)
         | 
| 130 | 
            +
                """
         | 
| 131 | 
            +
             | 
| 132 | 
            +
             | 
| 133 | 
            +
                with open(path_to_model_file, "rb") as fid:
         | 
| 134 | 
            +
                    num_points = read_next_bytes(fid, 8, "Q")[0]
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    xyzs = np.empty((num_points, 3))
         | 
| 137 | 
            +
                    rgbs = np.empty((num_points, 3))
         | 
| 138 | 
            +
                    errors = np.empty((num_points, 1))
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    for p_id in range(num_points):
         | 
| 141 | 
            +
                        binary_point_line_properties = read_next_bytes(
         | 
| 142 | 
            +
                            fid, num_bytes=43, format_char_sequence="QdddBBBd")
         | 
| 143 | 
            +
                        xyz = np.array(binary_point_line_properties[1:4])
         | 
| 144 | 
            +
                        rgb = np.array(binary_point_line_properties[4:7])
         | 
| 145 | 
            +
                        error = np.array(binary_point_line_properties[7])
         | 
| 146 | 
            +
                        track_length = read_next_bytes(
         | 
| 147 | 
            +
                            fid, num_bytes=8, format_char_sequence="Q")[0]
         | 
| 148 | 
            +
                        track_elems = read_next_bytes(
         | 
| 149 | 
            +
                            fid, num_bytes=8*track_length,
         | 
| 150 | 
            +
                            format_char_sequence="ii"*track_length)
         | 
| 151 | 
            +
                        xyzs[p_id] = xyz
         | 
| 152 | 
            +
                        rgbs[p_id] = rgb
         | 
| 153 | 
            +
                        errors[p_id] = error
         | 
| 154 | 
            +
                return xyzs, rgbs, errors
         | 
| 155 | 
            +
             | 
| 156 | 
            +
            def read_intrinsics_text(path):
         | 
| 157 | 
            +
                """
         | 
| 158 | 
            +
                Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
         | 
| 159 | 
            +
                """
         | 
| 160 | 
            +
                cameras = {}
         | 
| 161 | 
            +
                with open(path, "r") as fid:
         | 
| 162 | 
            +
                    while True:
         | 
| 163 | 
            +
                        line = fid.readline()
         | 
| 164 | 
            +
                        if not line:
         | 
| 165 | 
            +
                            break
         | 
| 166 | 
            +
                        line = line.strip()
         | 
| 167 | 
            +
                        if len(line) > 0 and line[0] != "#":
         | 
| 168 | 
            +
                            elems = line.split()
         | 
| 169 | 
            +
                            camera_id = int(elems[0])
         | 
| 170 | 
            +
                            model = elems[1]
         | 
| 171 | 
            +
                            assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE"
         | 
| 172 | 
            +
                            width = int(elems[2])
         | 
| 173 | 
            +
                            height = int(elems[3])
         | 
| 174 | 
            +
                            params = np.array(tuple(map(float, elems[4:])))
         | 
| 175 | 
            +
                            cameras[camera_id] = Camera(id=camera_id, model=model,
         | 
| 176 | 
            +
                                                        width=width, height=height,
         | 
| 177 | 
            +
                                                        params=params)
         | 
| 178 | 
            +
                return cameras
         | 
| 179 | 
            +
             | 
| 180 | 
            +
            def read_extrinsics_binary(path_to_model_file):
         | 
| 181 | 
            +
                """
         | 
| 182 | 
            +
                see: src/base/reconstruction.cc
         | 
| 183 | 
            +
                    void Reconstruction::ReadImagesBinary(const std::string& path)
         | 
| 184 | 
            +
                    void Reconstruction::WriteImagesBinary(const std::string& path)
         | 
| 185 | 
            +
                """
         | 
| 186 | 
            +
                images = {}
         | 
| 187 | 
            +
                with open(path_to_model_file, "rb") as fid:
         | 
| 188 | 
            +
                    num_reg_images = read_next_bytes(fid, 8, "Q")[0]
         | 
| 189 | 
            +
                    for _ in range(num_reg_images):
         | 
| 190 | 
            +
                        binary_image_properties = read_next_bytes(
         | 
| 191 | 
            +
                            fid, num_bytes=64, format_char_sequence="idddddddi")
         | 
| 192 | 
            +
                        image_id = binary_image_properties[0]
         | 
| 193 | 
            +
                        qvec = np.array(binary_image_properties[1:5])
         | 
| 194 | 
            +
                        tvec = np.array(binary_image_properties[5:8])
         | 
| 195 | 
            +
                        camera_id = binary_image_properties[8]
         | 
| 196 | 
            +
                        image_name = ""
         | 
| 197 | 
            +
                        current_char = read_next_bytes(fid, 1, "c")[0]
         | 
| 198 | 
            +
                        while current_char != b"\x00":   # look for the ASCII 0 entry
         | 
| 199 | 
            +
                            image_name += current_char.decode("utf-8")
         | 
| 200 | 
            +
                            current_char = read_next_bytes(fid, 1, "c")[0]
         | 
| 201 | 
            +
                        num_points2D = read_next_bytes(fid, num_bytes=8,
         | 
| 202 | 
            +
                                                       format_char_sequence="Q")[0]
         | 
| 203 | 
            +
                        x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
         | 
| 204 | 
            +
                                                   format_char_sequence="ddq"*num_points2D)
         | 
| 205 | 
            +
                        xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
         | 
| 206 | 
            +
                                               tuple(map(float, x_y_id_s[1::3]))])
         | 
| 207 | 
            +
                        point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
         | 
| 208 | 
            +
                        images[image_id] = Image(
         | 
| 209 | 
            +
                            id=image_id, qvec=qvec, tvec=tvec,
         | 
| 210 | 
            +
                            camera_id=camera_id, name=image_name,
         | 
| 211 | 
            +
                            xys=xys, point3D_ids=point3D_ids)
         | 
| 212 | 
            +
                return images
         | 
| 213 | 
            +
             | 
| 214 | 
            +
             | 
| 215 | 
            +
            def read_intrinsics_binary(path_to_model_file):
         | 
| 216 | 
            +
                """
         | 
| 217 | 
            +
                see: src/base/reconstruction.cc
         | 
| 218 | 
            +
                    void Reconstruction::WriteCamerasBinary(const std::string& path)
         | 
| 219 | 
            +
                    void Reconstruction::ReadCamerasBinary(const std::string& path)
         | 
| 220 | 
            +
                """
         | 
| 221 | 
            +
                cameras = {}
         | 
| 222 | 
            +
                with open(path_to_model_file, "rb") as fid:
         | 
| 223 | 
            +
                    num_cameras = read_next_bytes(fid, 8, "Q")[0]
         | 
| 224 | 
            +
                    for _ in range(num_cameras):
         | 
| 225 | 
            +
                        camera_properties = read_next_bytes(
         | 
| 226 | 
            +
                            fid, num_bytes=24, format_char_sequence="iiQQ")
         | 
| 227 | 
            +
                        camera_id = camera_properties[0]
         | 
| 228 | 
            +
                        model_id = camera_properties[1]
         | 
| 229 | 
            +
                        model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
         | 
| 230 | 
            +
                        width = camera_properties[2]
         | 
| 231 | 
            +
                        height = camera_properties[3]
         | 
| 232 | 
            +
                        num_params = CAMERA_MODEL_IDS[model_id].num_params
         | 
| 233 | 
            +
                        params = read_next_bytes(fid, num_bytes=8*num_params,
         | 
| 234 | 
            +
                                                 format_char_sequence="d"*num_params)
         | 
| 235 | 
            +
                        cameras[camera_id] = Camera(id=camera_id,
         | 
| 236 | 
            +
                                                    model=model_name,
         | 
| 237 | 
            +
                                                    width=width,
         | 
| 238 | 
            +
                                                    height=height,
         | 
| 239 | 
            +
                                                    params=np.array(params))
         | 
| 240 | 
            +
                    assert len(cameras) == num_cameras
         | 
| 241 | 
            +
                return cameras
         | 
| 242 | 
            +
             | 
| 243 | 
            +
             | 
| 244 | 
            +
            def read_extrinsics_text(path):
         | 
| 245 | 
            +
                """
         | 
| 246 | 
            +
                Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
         | 
| 247 | 
            +
                """
         | 
| 248 | 
            +
                images = {}
         | 
| 249 | 
            +
                with open(path, "r") as fid:
         | 
| 250 | 
            +
                    while True:
         | 
| 251 | 
            +
                        line = fid.readline()
         | 
| 252 | 
            +
                        if not line:
         | 
| 253 | 
            +
                            break
         | 
| 254 | 
            +
                        line = line.strip()
         | 
| 255 | 
            +
                        if len(line) > 0 and line[0] != "#":
         | 
| 256 | 
            +
                            elems = line.split()
         | 
| 257 | 
            +
                            image_id = int(elems[0])
         | 
| 258 | 
            +
                            qvec = np.array(tuple(map(float, elems[1:5])))
         | 
| 259 | 
            +
                            tvec = np.array(tuple(map(float, elems[5:8])))
         | 
| 260 | 
            +
                            camera_id = int(elems[8])
         | 
| 261 | 
            +
                            image_name = elems[9]
         | 
| 262 | 
            +
                            elems = fid.readline().split()
         | 
| 263 | 
            +
                            xys = np.column_stack([tuple(map(float, elems[0::3])),
         | 
| 264 | 
            +
                                                   tuple(map(float, elems[1::3]))])
         | 
| 265 | 
            +
                            point3D_ids = np.array(tuple(map(int, elems[2::3])))
         | 
| 266 | 
            +
                            images[image_id] = Image(
         | 
| 267 | 
            +
                                id=image_id, qvec=qvec, tvec=tvec,
         | 
| 268 | 
            +
                                camera_id=camera_id, name=image_name,
         | 
| 269 | 
            +
                                xys=xys, point3D_ids=point3D_ids)
         | 
| 270 | 
            +
                return images
         | 
| 271 | 
            +
             | 
| 272 | 
            +
             | 
| 273 | 
            +
            def read_colmap_bin_array(path):
         | 
| 274 | 
            +
                """
         | 
| 275 | 
            +
                Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                :param path: path to the colmap binary file.
         | 
| 278 | 
            +
                :return: nd array with the floating point values in the value
         | 
| 279 | 
            +
                """
         | 
| 280 | 
            +
                with open(path, "rb") as fid:
         | 
| 281 | 
            +
                    width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
         | 
| 282 | 
            +
                                                            usecols=(0, 1, 2), dtype=int)
         | 
| 283 | 
            +
                    fid.seek(0)
         | 
| 284 | 
            +
                    num_delimiter = 0
         | 
| 285 | 
            +
                    byte = fid.read(1)
         | 
| 286 | 
            +
                    while True:
         | 
| 287 | 
            +
                        if byte == b"&":
         | 
| 288 | 
            +
                            num_delimiter += 1
         | 
| 289 | 
            +
                            if num_delimiter >= 3:
         | 
| 290 | 
            +
                                break
         | 
| 291 | 
            +
                        byte = fid.read(1)
         | 
| 292 | 
            +
                    array = np.fromfile(fid, np.float32)
         | 
| 293 | 
            +
                array = array.reshape((width, height, channels), order="F")
         | 
| 294 | 
            +
                return np.transpose(array, (1, 0, 2)).squeeze()
         | 
    	
        scene/dataset_readers.py
    ADDED
    
    | @@ -0,0 +1,363 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import os
         | 
| 13 | 
            +
            import sys
         | 
| 14 | 
            +
            from PIL import Image
         | 
| 15 | 
            +
            from typing import NamedTuple
         | 
| 16 | 
            +
            from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
         | 
| 17 | 
            +
                read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
         | 
| 18 | 
            +
            from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
         | 
| 19 | 
            +
            import numpy as np
         | 
| 20 | 
            +
            import json
         | 
| 21 | 
            +
            from pathlib import Path
         | 
| 22 | 
            +
            from plyfile import PlyData, PlyElement
         | 
| 23 | 
            +
            from utils.sh_utils import SH2RGB
         | 
| 24 | 
            +
            from scene.gaussian_model import BasicPointCloud
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            class CameraInfo(NamedTuple):
         | 
| 27 | 
            +
                uid: int
         | 
| 28 | 
            +
                R: np.array
         | 
| 29 | 
            +
                T: np.array
         | 
| 30 | 
            +
                FovY: np.array
         | 
| 31 | 
            +
                FovX: np.array
         | 
| 32 | 
            +
                image: np.array
         | 
| 33 | 
            +
                image_path: str
         | 
| 34 | 
            +
                image_name: str
         | 
| 35 | 
            +
                width: int
         | 
| 36 | 
            +
                height: int
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
            class SceneInfo(NamedTuple):
         | 
| 40 | 
            +
                point_cloud: BasicPointCloud
         | 
| 41 | 
            +
                train_cameras: list
         | 
| 42 | 
            +
                test_cameras: list
         | 
| 43 | 
            +
                nerf_normalization: dict
         | 
| 44 | 
            +
                ply_path: str
         | 
| 45 | 
            +
                train_poses: list
         | 
| 46 | 
            +
                test_poses: list
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            def getNerfppNorm(cam_info):
         | 
| 49 | 
            +
                def get_center_and_diag(cam_centers):
         | 
| 50 | 
            +
                    cam_centers = np.hstack(cam_centers)
         | 
| 51 | 
            +
                    avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
         | 
| 52 | 
            +
                    center = avg_cam_center
         | 
| 53 | 
            +
                    dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
         | 
| 54 | 
            +
                    diagonal = np.max(dist)
         | 
| 55 | 
            +
                    return center.flatten(), diagonal
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                cam_centers = []
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                for cam in cam_info:
         | 
| 60 | 
            +
                    W2C = getWorld2View2(cam.R, cam.T)
         | 
| 61 | 
            +
                    C2W = np.linalg.inv(W2C)
         | 
| 62 | 
            +
                    cam_centers.append(C2W[:3, 3:4])
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                center, diagonal = get_center_and_diag(cam_centers)
         | 
| 65 | 
            +
                radius = diagonal * 1.1
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                translate = -center
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                return {"translate": translate, "radius": radius}
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, eval):
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                cam_infos = []
         | 
| 74 | 
            +
                poses=[]
         | 
| 75 | 
            +
                for idx, key in enumerate(cam_extrinsics):
         | 
| 76 | 
            +
                    sys.stdout.write('\r')
         | 
| 77 | 
            +
                    # the exact output you're looking for:
         | 
| 78 | 
            +
                    sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
         | 
| 79 | 
            +
                    sys.stdout.flush()
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    if eval:
         | 
| 82 | 
            +
                        extr = cam_extrinsics[key]
         | 
| 83 | 
            +
                        intr = cam_intrinsics[1]
         | 
| 84 | 
            +
                        uid = idx+1
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                    else:
         | 
| 87 | 
            +
                        extr = cam_extrinsics[key]
         | 
| 88 | 
            +
                        intr = cam_intrinsics[extr.camera_id]
         | 
| 89 | 
            +
                        uid = intr.id
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                    height = intr.height
         | 
| 92 | 
            +
                    width = intr.width            
         | 
| 93 | 
            +
                    R = np.transpose(qvec2rotmat(extr.qvec))
         | 
| 94 | 
            +
                    T = np.array(extr.tvec)
         | 
| 95 | 
            +
                    pose =  np.vstack((np.hstack((R, T.reshape(3,-1))),np.array([[0, 0, 0, 1]])))
         | 
| 96 | 
            +
                    poses.append(pose)
         | 
| 97 | 
            +
                    if intr.model=="SIMPLE_PINHOLE":
         | 
| 98 | 
            +
                        focal_length_x = intr.params[0]
         | 
| 99 | 
            +
                        FovY = focal2fov(focal_length_x, height)
         | 
| 100 | 
            +
                        FovX = focal2fov(focal_length_x, width)
         | 
| 101 | 
            +
                    elif intr.model=="PINHOLE":
         | 
| 102 | 
            +
                        focal_length_x = intr.params[0]
         | 
| 103 | 
            +
                        focal_length_y = intr.params[1]
         | 
| 104 | 
            +
                        FovY = focal2fov(focal_length_y, height)
         | 
| 105 | 
            +
                        FovX = focal2fov(focal_length_x, width)
         | 
| 106 | 
            +
                    else:
         | 
| 107 | 
            +
                        assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
                    if eval:
         | 
| 111 | 
            +
                        tmp = os.path.dirname(os.path.dirname(os.path.join(images_folder)))
         | 
| 112 | 
            +
                        all_images_folder = os.path.join(tmp, 'images')
         | 
| 113 | 
            +
                        image_path = os.path.join(all_images_folder, os.path.basename(extr.name))
         | 
| 114 | 
            +
                    else:
         | 
| 115 | 
            +
                        image_path = os.path.join(images_folder, os.path.basename(extr.name))
         | 
| 116 | 
            +
                    image_name = os.path.basename(image_path).split(".")[0]
         | 
| 117 | 
            +
                    image = Image.open(image_path)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
             | 
| 120 | 
            +
                    cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
         | 
| 121 | 
            +
                                          image_path=image_path, image_name=image_name, width=width, height=height)
         | 
| 122 | 
            +
                
         | 
| 123 | 
            +
                    cam_infos.append(cam_info)
         | 
| 124 | 
            +
                sys.stdout.write('\n')
         | 
| 125 | 
            +
                return cam_infos, poses
         | 
| 126 | 
            +
             | 
| 127 | 
            +
            # For interpolated video, open when only render interpolated video
         | 
| 128 | 
            +
            def readColmapCamerasInterp(cam_extrinsics, cam_intrinsics, images_folder, model_path):
         | 
| 129 | 
            +
                
         | 
| 130 | 
            +
                pose_interpolated_path = model_path + 'pose/pose_interpolated.npy'
         | 
| 131 | 
            +
                pose_interpolated = np.load(pose_interpolated_path)
         | 
| 132 | 
            +
                intr = cam_intrinsics[1]
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                cam_infos = []
         | 
| 135 | 
            +
                poses=[]
         | 
| 136 | 
            +
                for idx, pose_npy in enumerate(pose_interpolated):
         | 
| 137 | 
            +
                    sys.stdout.write('\r')
         | 
| 138 | 
            +
                    sys.stdout.write("Reading camera {}/{}".format(idx+1, pose_interpolated.shape[0]))
         | 
| 139 | 
            +
                    sys.stdout.flush()
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    extr = pose_npy
         | 
| 142 | 
            +
                    intr = intr
         | 
| 143 | 
            +
                    height = intr.height
         | 
| 144 | 
            +
                    width = intr.width
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    uid = idx
         | 
| 147 | 
            +
                    R = extr[:3, :3].transpose()
         | 
| 148 | 
            +
                    T = extr[:3, 3]
         | 
| 149 | 
            +
                    pose =  np.vstack((np.hstack((R, T.reshape(3,-1))),np.array([[0, 0, 0, 1]])))
         | 
| 150 | 
            +
                    # print(uid)
         | 
| 151 | 
            +
                    # print(pose.shape)
         | 
| 152 | 
            +
                    # pose = np.linalg.inv(pose)
         | 
| 153 | 
            +
                    poses.append(pose)
         | 
| 154 | 
            +
                    if intr.model=="SIMPLE_PINHOLE":
         | 
| 155 | 
            +
                        focal_length_x = intr.params[0]
         | 
| 156 | 
            +
                        FovY = focal2fov(focal_length_x, height)
         | 
| 157 | 
            +
                        FovX = focal2fov(focal_length_x, width)
         | 
| 158 | 
            +
                    elif intr.model=="PINHOLE":
         | 
| 159 | 
            +
                        focal_length_x = intr.params[0]
         | 
| 160 | 
            +
                        focal_length_y = intr.params[1]
         | 
| 161 | 
            +
                        FovY = focal2fov(focal_length_y, height)
         | 
| 162 | 
            +
                        FovX = focal2fov(focal_length_x, width)
         | 
| 163 | 
            +
                    else:
         | 
| 164 | 
            +
                        assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    images_list = os.listdir(os.path.join(images_folder))
         | 
| 167 | 
            +
                    image_name_0 = images_list[0]
         | 
| 168 | 
            +
                    image_name = str(idx).zfill(4)
         | 
| 169 | 
            +
                    image = Image.open(images_folder + '/' + image_name_0)
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
         | 
| 172 | 
            +
                                          image_path=images_folder, image_name=image_name, width=width, height=height)
         | 
| 173 | 
            +
                    cam_infos.append(cam_info)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                sys.stdout.write('\n')
         | 
| 176 | 
            +
                return cam_infos, poses
         | 
| 177 | 
            +
             | 
| 178 | 
            +
             | 
| 179 | 
            +
            def fetchPly(path):
         | 
| 180 | 
            +
                plydata = PlyData.read(path)
         | 
| 181 | 
            +
                vertices = plydata['vertex']
         | 
| 182 | 
            +
                positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
         | 
| 183 | 
            +
                colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
         | 
| 184 | 
            +
                normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
         | 
| 185 | 
            +
                return BasicPointCloud(points=positions, colors=colors, normals=normals)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
            def storePly(path, xyz, rgb):
         | 
| 188 | 
            +
                # Define the dtype for the structured array
         | 
| 189 | 
            +
                dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
         | 
| 190 | 
            +
                        ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
         | 
| 191 | 
            +
                        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
         | 
| 192 | 
            +
                
         | 
| 193 | 
            +
                normals = np.zeros_like(xyz)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                elements = np.empty(xyz.shape[0], dtype=dtype)
         | 
| 196 | 
            +
                attributes = np.concatenate((xyz, normals, rgb), axis=1)
         | 
| 197 | 
            +
                elements[:] = list(map(tuple, attributes))
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                # Create the PlyData object and write to file
         | 
| 200 | 
            +
                vertex_element = PlyElement.describe(elements, 'vertex')
         | 
| 201 | 
            +
                ply_data = PlyData([vertex_element])
         | 
| 202 | 
            +
                ply_data.write(path)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
            def readColmapSceneInfo(path, images, eval, args, opt, llffhold=2):
         | 
| 205 | 
            +
                # try:
         | 
| 206 | 
            +
                #     cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
         | 
| 207 | 
            +
                #     cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
         | 
| 208 | 
            +
                #     cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
         | 
| 209 | 
            +
                #     cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
         | 
| 210 | 
            +
                # except:
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                ##### For initializing test pose using PCD_Registration
         | 
| 213 | 
            +
                if eval and opt.get_video==False:    
         | 
| 214 | 
            +
                    print("Loading initial test pose for evaluation.")
         | 
| 215 | 
            +
                    cameras_extrinsic_file = os.path.join(path, "init_test_pose/sparse/0", "images.txt")
         | 
| 216 | 
            +
                else:
         | 
| 217 | 
            +
                    cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
         | 
| 218 | 
            +
             | 
| 219 | 
            +
                cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
         | 
| 220 | 
            +
                cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
         | 
| 221 | 
            +
                cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                reading_dir = "images" if images == None else images
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                if opt.get_video:
         | 
| 226 | 
            +
                    cam_infos_unsorted, poses = readColmapCamerasInterp(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), model_path=args.model_path)
         | 
| 227 | 
            +
                else:
         | 
| 228 | 
            +
                    cam_infos_unsorted, poses = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), eval=eval)
         | 
| 229 | 
            +
                sorting_indices = sorted(range(len(cam_infos_unsorted)), key=lambda x: cam_infos_unsorted[x].image_name)
         | 
| 230 | 
            +
                cam_infos = [cam_infos_unsorted[i] for i in sorting_indices]
         | 
| 231 | 
            +
                sorted_poses = [poses[i] for i in sorting_indices]
         | 
| 232 | 
            +
                cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                if eval:
         | 
| 235 | 
            +
                    # train_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx+1) % llffhold != 0]
         | 
| 236 | 
            +
                    # test_cam_infos = [c for idx, c in enumerate(cam_infos) if (idx+1) % llffhold == 0]
         | 
| 237 | 
            +
                    # train_poses = [c for idx, c in enumerate(sorted_poses) if (idx+1) % llffhold != 0]
         | 
| 238 | 
            +
                    # test_poses = [c for idx, c in enumerate(sorted_poses) if (idx+1) % llffhold == 0]
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    train_cam_infos = cam_infos
         | 
| 241 | 
            +
                    test_cam_infos = cam_infos
         | 
| 242 | 
            +
                    train_poses = sorted_poses
         | 
| 243 | 
            +
                    test_poses = sorted_poses
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                else:
         | 
| 246 | 
            +
                    train_cam_infos = cam_infos
         | 
| 247 | 
            +
                    test_cam_infos = []
         | 
| 248 | 
            +
                    train_poses = sorted_poses
         | 
| 249 | 
            +
                    test_poses = []
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                nerf_normalization = getNerfppNorm(train_cam_infos)
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                ply_path = os.path.join(path, "sparse/0/points3D.ply")
         | 
| 254 | 
            +
                bin_path = os.path.join(path, "sparse/0/points3D.bin")
         | 
| 255 | 
            +
                txt_path = os.path.join(path, "sparse/0/points3D.txt")
         | 
| 256 | 
            +
                if not os.path.exists(ply_path):
         | 
| 257 | 
            +
                    print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
         | 
| 258 | 
            +
                    try:
         | 
| 259 | 
            +
                        xyz, rgb, _ = read_points3D_binary(bin_path)
         | 
| 260 | 
            +
                    except:
         | 
| 261 | 
            +
                        xyz, rgb, _ = read_points3D_text(txt_path)
         | 
| 262 | 
            +
                    storePly(ply_path, xyz, rgb)
         | 
| 263 | 
            +
                try:
         | 
| 264 | 
            +
                    pcd = fetchPly(ply_path)
         | 
| 265 | 
            +
                except:
         | 
| 266 | 
            +
                    pcd = None
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                # np.save("poses_family.npy", sorted_poses)
         | 
| 269 | 
            +
                # breakpoint()
         | 
| 270 | 
            +
                # np.save("3dpoints.npy", pcd.points)
         | 
| 271 | 
            +
                # np.save("3dcolors.npy", pcd.colors)
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                scene_info = SceneInfo(point_cloud=pcd,
         | 
| 274 | 
            +
                                       train_cameras=train_cam_infos,
         | 
| 275 | 
            +
                                       test_cameras=test_cam_infos,
         | 
| 276 | 
            +
                                       nerf_normalization=nerf_normalization,
         | 
| 277 | 
            +
                                       ply_path=ply_path,
         | 
| 278 | 
            +
                                       train_poses=train_poses,
         | 
| 279 | 
            +
                                       test_poses=test_poses)
         | 
| 280 | 
            +
                return scene_info
         | 
| 281 | 
            +
             | 
| 282 | 
            +
            def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
         | 
| 283 | 
            +
                cam_infos = []
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                with open(os.path.join(path, transformsfile)) as json_file:
         | 
| 286 | 
            +
                    contents = json.load(json_file)
         | 
| 287 | 
            +
                    fovx = contents["camera_angle_x"]
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                    frames = contents["frames"]
         | 
| 290 | 
            +
                    for idx, frame in enumerate(frames):
         | 
| 291 | 
            +
                        cam_name = os.path.join(path, frame["file_path"] + extension)
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                        # NeRF 'transform_matrix' is a camera-to-world transform
         | 
| 294 | 
            +
                        c2w = np.array(frame["transform_matrix"])
         | 
| 295 | 
            +
                        # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
         | 
| 296 | 
            +
                        c2w[:3, 1:3] *= -1
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                        # get the world-to-camera transform and set R, T
         | 
| 299 | 
            +
                        w2c = np.linalg.inv(c2w)
         | 
| 300 | 
            +
                        R = np.transpose(w2c[:3,:3])  # R is stored transposed due to 'glm' in CUDA code
         | 
| 301 | 
            +
                        T = w2c[:3, 3]
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                        image_path = os.path.join(path, cam_name)
         | 
| 304 | 
            +
                        image_name = Path(cam_name).stem
         | 
| 305 | 
            +
                        image = Image.open(image_path)
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                        im_data = np.array(image.convert("RGBA"))
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                        bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                        norm_data = im_data / 255.0
         | 
| 312 | 
            +
                        arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
         | 
| 313 | 
            +
                        image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                        fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
         | 
| 316 | 
            +
                        FovY = fovy 
         | 
| 317 | 
            +
                        FovX = fovx
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                        cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
         | 
| 320 | 
            +
                                        image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
         | 
| 321 | 
            +
                        
         | 
| 322 | 
            +
                return cam_infos
         | 
| 323 | 
            +
             | 
| 324 | 
            +
            def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
         | 
| 325 | 
            +
                print("Reading Training Transforms")
         | 
| 326 | 
            +
                train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
         | 
| 327 | 
            +
                print("Reading Test Transforms")
         | 
| 328 | 
            +
                test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
         | 
| 329 | 
            +
                
         | 
| 330 | 
            +
                if not eval:
         | 
| 331 | 
            +
                    train_cam_infos.extend(test_cam_infos)
         | 
| 332 | 
            +
                    test_cam_infos = []
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                nerf_normalization = getNerfppNorm(train_cam_infos)
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                ply_path = os.path.join(path, "points3d.ply")
         | 
| 337 | 
            +
                if not os.path.exists(ply_path):
         | 
| 338 | 
            +
                    # Since this data set has no colmap data, we start with random points
         | 
| 339 | 
            +
                    num_pts = 100_000
         | 
| 340 | 
            +
                    print(f"Generating random point cloud ({num_pts})...")
         | 
| 341 | 
            +
                    
         | 
| 342 | 
            +
                    # We create random points inside the bounds of the synthetic Blender scenes
         | 
| 343 | 
            +
                    xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
         | 
| 344 | 
            +
                    shs = np.random.random((num_pts, 3)) / 255.0
         | 
| 345 | 
            +
                    pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                    storePly(ply_path, xyz, SH2RGB(shs) * 255)
         | 
| 348 | 
            +
                try:
         | 
| 349 | 
            +
                    pcd = fetchPly(ply_path)
         | 
| 350 | 
            +
                except:
         | 
| 351 | 
            +
                    pcd = None
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                scene_info = SceneInfo(point_cloud=pcd,
         | 
| 354 | 
            +
                                       train_cameras=train_cam_infos,
         | 
| 355 | 
            +
                                       test_cameras=test_cam_infos,
         | 
| 356 | 
            +
                                       nerf_normalization=nerf_normalization,
         | 
| 357 | 
            +
                                       ply_path=ply_path)
         | 
| 358 | 
            +
                return scene_info
         | 
| 359 | 
            +
             | 
| 360 | 
            +
            sceneLoadTypeCallbacks = {
         | 
| 361 | 
            +
                "Colmap": readColmapSceneInfo,
         | 
| 362 | 
            +
                "Blender" : readNerfSyntheticInfo
         | 
| 363 | 
            +
            }
         | 
    	
        scene/gaussian_model.py
    ADDED
    
    | @@ -0,0 +1,502 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            # from lietorch import SO3, SE3, Sim3, LieGroupParameter
         | 
| 14 | 
            +
            import numpy as np
         | 
| 15 | 
            +
            from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
         | 
| 16 | 
            +
            from torch import nn
         | 
| 17 | 
            +
            import os
         | 
| 18 | 
            +
            from utils.system_utils import mkdir_p
         | 
| 19 | 
            +
            from plyfile import PlyData, PlyElement
         | 
| 20 | 
            +
            from utils.sh_utils import RGB2SH
         | 
| 21 | 
            +
            from simple_knn._C import distCUDA2
         | 
| 22 | 
            +
            from utils.graphics_utils import BasicPointCloud
         | 
| 23 | 
            +
            from utils.general_utils import strip_symmetric, build_scaling_rotation
         | 
| 24 | 
            +
            from scipy.spatial.transform import Rotation as R
         | 
| 25 | 
            +
            from utils.pose_utils import rotation2quad, get_tensor_from_camera
         | 
| 26 | 
            +
            from utils.graphics_utils import getWorld2View2
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            def quaternion_to_rotation_matrix(quaternion):
         | 
| 29 | 
            +
                """
         | 
| 30 | 
            +
                Convert a quaternion to a rotation matrix.
         | 
| 31 | 
            +
                
         | 
| 32 | 
            +
                Parameters:
         | 
| 33 | 
            +
                - quaternion: A tensor of shape (..., 4) representing quaternions.
         | 
| 34 | 
            +
                
         | 
| 35 | 
            +
                Returns:
         | 
| 36 | 
            +
                - A tensor of shape (..., 3, 3) representing rotation matrices.
         | 
| 37 | 
            +
                """
         | 
| 38 | 
            +
                # Ensure quaternion is of float type for computation
         | 
| 39 | 
            +
                quaternion = quaternion.float()
         | 
| 40 | 
            +
                
         | 
| 41 | 
            +
                # Normalize the quaternion to unit length
         | 
| 42 | 
            +
                quaternion = quaternion / quaternion.norm(p=2, dim=-1, keepdim=True)
         | 
| 43 | 
            +
                
         | 
| 44 | 
            +
                # Extract components
         | 
| 45 | 
            +
                w, x, y, z = quaternion[..., 0], quaternion[..., 1], quaternion[..., 2], quaternion[..., 3]
         | 
| 46 | 
            +
                
         | 
| 47 | 
            +
                # Compute rotation matrix components
         | 
| 48 | 
            +
                xx, yy, zz = x * x, y * y, z * z
         | 
| 49 | 
            +
                xy, xz, yz = x * y, x * z, y * z
         | 
| 50 | 
            +
                xw, yw, zw = x * w, y * w, z * w
         | 
| 51 | 
            +
                
         | 
| 52 | 
            +
                # Assemble the rotation matrix
         | 
| 53 | 
            +
                R = torch.stack([
         | 
| 54 | 
            +
                    torch.stack([1 - 2 * (yy + zz),     2 * (xy - zw),     2 * (xz + yw)], dim=-1),
         | 
| 55 | 
            +
                    torch.stack([    2 * (xy + zw), 1 - 2 * (xx + zz),     2 * (yz - xw)], dim=-1),
         | 
| 56 | 
            +
                    torch.stack([    2 * (xz - yw),     2 * (yz + xw), 1 - 2 * (xx + yy)], dim=-1)
         | 
| 57 | 
            +
                ], dim=-2)
         | 
| 58 | 
            +
                
         | 
| 59 | 
            +
                return R
         | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
            class GaussianModel:
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                def setup_functions(self):
         | 
| 65 | 
            +
                    def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
         | 
| 66 | 
            +
                        L = build_scaling_rotation(scaling_modifier * scaling, rotation)
         | 
| 67 | 
            +
                        actual_covariance = L @ L.transpose(1, 2)
         | 
| 68 | 
            +
                        symm = strip_symmetric(actual_covariance)
         | 
| 69 | 
            +
                        return symm
         | 
| 70 | 
            +
                    
         | 
| 71 | 
            +
                    self.scaling_activation = torch.exp
         | 
| 72 | 
            +
                    self.scaling_inverse_activation = torch.log
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    self.covariance_activation = build_covariance_from_scaling_rotation
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    self.opacity_activation = torch.sigmoid
         | 
| 77 | 
            +
                    self.inverse_opacity_activation = inverse_sigmoid
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    self.rotation_activation = torch.nn.functional.normalize
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
                def __init__(self, sh_degree : int):
         | 
| 83 | 
            +
                    self.active_sh_degree = 0
         | 
| 84 | 
            +
                    self.max_sh_degree = sh_degree  
         | 
| 85 | 
            +
                    self._xyz = torch.empty(0)
         | 
| 86 | 
            +
                    self._features_dc = torch.empty(0)
         | 
| 87 | 
            +
                    self._features_rest = torch.empty(0)
         | 
| 88 | 
            +
                    self._scaling = torch.empty(0)
         | 
| 89 | 
            +
                    self._rotation = torch.empty(0)
         | 
| 90 | 
            +
                    self._opacity = torch.empty(0)
         | 
| 91 | 
            +
                    self.max_radii2D = torch.empty(0)
         | 
| 92 | 
            +
                    self.xyz_gradient_accum = torch.empty(0)
         | 
| 93 | 
            +
                    self.denom = torch.empty(0)
         | 
| 94 | 
            +
                    self.optimizer = None
         | 
| 95 | 
            +
                    self.percent_dense = 0
         | 
| 96 | 
            +
                    self.spatial_lr_scale = 0
         | 
| 97 | 
            +
                    self.setup_functions()
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                def capture(self):
         | 
| 100 | 
            +
                    return (
         | 
| 101 | 
            +
                        self.active_sh_degree,
         | 
| 102 | 
            +
                        self._xyz,
         | 
| 103 | 
            +
                        self._features_dc,
         | 
| 104 | 
            +
                        self._features_rest,
         | 
| 105 | 
            +
                        self._scaling,
         | 
| 106 | 
            +
                        self._rotation,
         | 
| 107 | 
            +
                        self._opacity,
         | 
| 108 | 
            +
                        self.max_radii2D,
         | 
| 109 | 
            +
                        self.xyz_gradient_accum,
         | 
| 110 | 
            +
                        self.denom,
         | 
| 111 | 
            +
                        self.optimizer.state_dict(),
         | 
| 112 | 
            +
                        self.spatial_lr_scale,
         | 
| 113 | 
            +
                        self.P,
         | 
| 114 | 
            +
                    )
         | 
| 115 | 
            +
                
         | 
| 116 | 
            +
                def restore(self, model_args, training_args):
         | 
| 117 | 
            +
                    (self.active_sh_degree, 
         | 
| 118 | 
            +
                    self._xyz, 
         | 
| 119 | 
            +
                    self._features_dc, 
         | 
| 120 | 
            +
                    self._features_rest,
         | 
| 121 | 
            +
                    self._scaling, 
         | 
| 122 | 
            +
                    self._rotation, 
         | 
| 123 | 
            +
                    self._opacity,
         | 
| 124 | 
            +
                    self.max_radii2D, 
         | 
| 125 | 
            +
                    xyz_gradient_accum, 
         | 
| 126 | 
            +
                    denom,
         | 
| 127 | 
            +
                    opt_dict, 
         | 
| 128 | 
            +
                    self.spatial_lr_scale,
         | 
| 129 | 
            +
                    self.P) = model_args
         | 
| 130 | 
            +
                    self.training_setup(training_args)
         | 
| 131 | 
            +
                    self.xyz_gradient_accum = xyz_gradient_accum
         | 
| 132 | 
            +
                    self.denom = denom
         | 
| 133 | 
            +
                    self.optimizer.load_state_dict(opt_dict)
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                @property
         | 
| 136 | 
            +
                def get_scaling(self):
         | 
| 137 | 
            +
                    return self.scaling_activation(self._scaling)
         | 
| 138 | 
            +
                
         | 
| 139 | 
            +
                @property
         | 
| 140 | 
            +
                def get_rotation(self):
         | 
| 141 | 
            +
                    return self.rotation_activation(self._rotation)
         | 
| 142 | 
            +
                
         | 
| 143 | 
            +
                @property
         | 
| 144 | 
            +
                def get_xyz(self):
         | 
| 145 | 
            +
                    return self._xyz
         | 
| 146 | 
            +
                
         | 
| 147 | 
            +
                def compute_relative_world_to_camera(self, R1, t1, R2, t2):
         | 
| 148 | 
            +
                    # Create a row of zeros with a one at the end, for homogeneous coordinates
         | 
| 149 | 
            +
                    zero_row = np.array([[0, 0, 0, 1]], dtype=np.float32)
         | 
| 150 | 
            +
                    
         | 
| 151 | 
            +
                    # Compute the inverse of the first extrinsic matrix
         | 
| 152 | 
            +
                    E1_inv = np.hstack([R1.T, -R1.T @ t1.reshape(-1, 1)])  # Transpose and reshape for correct dimensions
         | 
| 153 | 
            +
                    E1_inv = np.vstack([E1_inv, zero_row])  # Append the zero_row to make it a 4x4 matrix
         | 
| 154 | 
            +
                    
         | 
| 155 | 
            +
                    # Compute the second extrinsic matrix
         | 
| 156 | 
            +
                    E2 = np.hstack([R2, -R2 @ t2.reshape(-1, 1)])  # No need to transpose R2
         | 
| 157 | 
            +
                    E2 = np.vstack([E2, zero_row])  # Append the zero_row to make it a 4x4 matrix
         | 
| 158 | 
            +
                    
         | 
| 159 | 
            +
                    # Compute the relative transformation
         | 
| 160 | 
            +
                    E_rel = E2 @ E1_inv
         | 
| 161 | 
            +
                    
         | 
| 162 | 
            +
                    return E_rel
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                def init_RT_seq(self, cam_list):        
         | 
| 165 | 
            +
                    poses =[]
         | 
| 166 | 
            +
                    for cam in cam_list[1.0]:
         | 
| 167 | 
            +
                        p = get_tensor_from_camera(cam.world_view_transform.transpose(0, 1)) # R T -> quat t
         | 
| 168 | 
            +
                        poses.append(p)
         | 
| 169 | 
            +
                    poses = torch.stack(poses)
         | 
| 170 | 
            +
                    self.P = poses.cuda().requires_grad_(True)
         | 
| 171 | 
            +
                    
         | 
| 172 | 
            +
                
         | 
| 173 | 
            +
                def get_RT(self, idx):
         | 
| 174 | 
            +
                    pose = self.P[idx]
         | 
| 175 | 
            +
                    return pose
         | 
| 176 | 
            +
                
         | 
| 177 | 
            +
                def get_RT_test(self, idx):
         | 
| 178 | 
            +
                    pose = self.test_P[idx]
         | 
| 179 | 
            +
                    return pose
         | 
| 180 | 
            +
                
         | 
| 181 | 
            +
                @property
         | 
| 182 | 
            +
                def get_features(self):
         | 
| 183 | 
            +
                    features_dc = self._features_dc
         | 
| 184 | 
            +
                    features_rest = self._features_rest
         | 
| 185 | 
            +
                    return torch.cat((features_dc, features_rest), dim=1)
         | 
| 186 | 
            +
                
         | 
| 187 | 
            +
                @property
         | 
| 188 | 
            +
                def get_opacity(self):
         | 
| 189 | 
            +
                    return self.opacity_activation(self._opacity)
         | 
| 190 | 
            +
                
         | 
| 191 | 
            +
                def get_covariance(self, scaling_modifier = 1):
         | 
| 192 | 
            +
                    return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                def oneupSHdegree(self):
         | 
| 195 | 
            +
                    if self.active_sh_degree < self.max_sh_degree:
         | 
| 196 | 
            +
                        self.active_sh_degree += 1
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # gradio
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    self.spatial_lr_scale = spatial_lr_scale
         | 
| 203 | 
            +
                    fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
         | 
| 204 | 
            +
                    fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
         | 
| 205 | 
            +
                    features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
         | 
| 206 | 
            +
                    features[:, :3, 0 ] = fused_color
         | 
| 207 | 
            +
                    features[:, 3:, 1:] = 0.0
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                    print("Number of points at initialisation : ", fused_point_cloud.shape[0])
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
         | 
| 212 | 
            +
                    scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
         | 
| 213 | 
            +
                    rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
         | 
| 214 | 
            +
                    rots[:, 0] = 1
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                    opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
         | 
| 219 | 
            +
                    self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
         | 
| 220 | 
            +
                    self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
         | 
| 221 | 
            +
                    self._scaling = nn.Parameter(scales.requires_grad_(True))
         | 
| 222 | 
            +
                    self._rotation = nn.Parameter(rots.requires_grad_(True))
         | 
| 223 | 
            +
                    self._opacity = nn.Parameter(opacities.requires_grad_(True))
         | 
| 224 | 
            +
                    self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
         | 
| 225 | 
            +
             | 
| 226 | 
            +
                def training_setup(self, training_args):
         | 
| 227 | 
            +
                    self.percent_dense = training_args.percent_dense
         | 
| 228 | 
            +
                    self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
         | 
| 229 | 
            +
                    self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                    l = [
         | 
| 232 | 
            +
                        {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
         | 
| 233 | 
            +
                        {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
         | 
| 234 | 
            +
                        {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
         | 
| 235 | 
            +
                        {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
         | 
| 236 | 
            +
                        {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
         | 
| 237 | 
            +
                        {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
         | 
| 238 | 
            +
                    ]
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                    l_cam = [{'params': [self.P],'lr': training_args.rotation_lr*0.1, "name": "pose"},]
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                    l += l_cam
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                    self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
         | 
| 245 | 
            +
                    self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
         | 
| 246 | 
            +
                                                                lr_final=training_args.position_lr_final*self.spatial_lr_scale,
         | 
| 247 | 
            +
                                                                lr_delay_mult=training_args.position_lr_delay_mult,
         | 
| 248 | 
            +
                                                                max_steps=training_args.position_lr_max_steps)
         | 
| 249 | 
            +
                    self.cam_scheduler_args = get_expon_lr_func(
         | 
| 250 | 
            +
                                                                # lr_init=0,
         | 
| 251 | 
            +
                                                                # lr_final=0,
         | 
| 252 | 
            +
                                                                lr_init=training_args.rotation_lr*0.1,
         | 
| 253 | 
            +
                                                                lr_final=training_args.rotation_lr*0.001,
         | 
| 254 | 
            +
                                                                # lr_init=training_args.position_lr_init*self.spatial_lr_scale*10,
         | 
| 255 | 
            +
                                                                # lr_final=training_args.position_lr_final*self.spatial_lr_scale*10,
         | 
| 256 | 
            +
                                                                lr_delay_mult=training_args.position_lr_delay_mult,
         | 
| 257 | 
            +
                                                                max_steps=1000)
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                def update_learning_rate(self, iteration):
         | 
| 260 | 
            +
                    ''' Learning rate scheduling per step '''
         | 
| 261 | 
            +
                    for param_group in self.optimizer.param_groups:
         | 
| 262 | 
            +
                        if param_group["name"] == "pose":
         | 
| 263 | 
            +
                            lr = self.cam_scheduler_args(iteration)
         | 
| 264 | 
            +
                            # print("pose learning rate", iteration, lr)
         | 
| 265 | 
            +
                            param_group['lr'] = lr
         | 
| 266 | 
            +
                        if param_group["name"] == "xyz":
         | 
| 267 | 
            +
                            lr = self.xyz_scheduler_args(iteration)
         | 
| 268 | 
            +
                            param_group['lr'] = lr
         | 
| 269 | 
            +
                    # return lr
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                def construct_list_of_attributes(self):
         | 
| 272 | 
            +
                    l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
         | 
| 273 | 
            +
                    # All channels except the 3 DC
         | 
| 274 | 
            +
                    for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
         | 
| 275 | 
            +
                        l.append('f_dc_{}'.format(i))
         | 
| 276 | 
            +
                    for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
         | 
| 277 | 
            +
                        l.append('f_rest_{}'.format(i))
         | 
| 278 | 
            +
                    l.append('opacity')
         | 
| 279 | 
            +
                    for i in range(self._scaling.shape[1]):
         | 
| 280 | 
            +
                        l.append('scale_{}'.format(i))
         | 
| 281 | 
            +
                    for i in range(self._rotation.shape[1]):
         | 
| 282 | 
            +
                        l.append('rot_{}'.format(i))
         | 
| 283 | 
            +
                    return l
         | 
| 284 | 
            +
             | 
| 285 | 
            +
                def save_ply(self, path):
         | 
| 286 | 
            +
                    mkdir_p(os.path.dirname(path))
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    xyz = self._xyz.detach().cpu().numpy()
         | 
| 289 | 
            +
                    normals = np.zeros_like(xyz)
         | 
| 290 | 
            +
                    f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
         | 
| 291 | 
            +
                    f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
         | 
| 292 | 
            +
                    opacities = self._opacity.detach().cpu().numpy()
         | 
| 293 | 
            +
                    scale = self._scaling.detach().cpu().numpy()
         | 
| 294 | 
            +
                    rotation = self._rotation.detach().cpu().numpy()
         | 
| 295 | 
            +
             | 
| 296 | 
            +
                    dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    elements = np.empty(xyz.shape[0], dtype=dtype_full)
         | 
| 299 | 
            +
                    attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
         | 
| 300 | 
            +
                    elements[:] = list(map(tuple, attributes))
         | 
| 301 | 
            +
                    el = PlyElement.describe(elements, 'vertex')
         | 
| 302 | 
            +
                    PlyData([el]).write(path)
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                def reset_opacity(self):
         | 
| 305 | 
            +
                    opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
         | 
| 306 | 
            +
                    optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
         | 
| 307 | 
            +
                    self._opacity = optimizable_tensors["opacity"]
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                def load_ply(self, path):
         | 
| 310 | 
            +
                    plydata = PlyData.read(path)
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                    xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
         | 
| 313 | 
            +
                                    np.asarray(plydata.elements[0]["y"]),
         | 
| 314 | 
            +
                                    np.asarray(plydata.elements[0]["z"])),  axis=1)
         | 
| 315 | 
            +
                    opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    features_dc = np.zeros((xyz.shape[0], 3, 1))
         | 
| 318 | 
            +
                    features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
         | 
| 319 | 
            +
                    features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
         | 
| 320 | 
            +
                    features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                    extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
         | 
| 323 | 
            +
                    extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
         | 
| 324 | 
            +
                    assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
         | 
| 325 | 
            +
                    features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
         | 
| 326 | 
            +
                    for idx, attr_name in enumerate(extra_f_names):
         | 
| 327 | 
            +
                        features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
         | 
| 328 | 
            +
                    # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
         | 
| 329 | 
            +
                    features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                    scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
         | 
| 332 | 
            +
                    scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
         | 
| 333 | 
            +
                    scales = np.zeros((xyz.shape[0], len(scale_names)))
         | 
| 334 | 
            +
                    for idx, attr_name in enumerate(scale_names):
         | 
| 335 | 
            +
                        scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                    rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
         | 
| 338 | 
            +
                    rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
         | 
| 339 | 
            +
                    rots = np.zeros((xyz.shape[0], len(rot_names)))
         | 
| 340 | 
            +
                    for idx, attr_name in enumerate(rot_names):
         | 
| 341 | 
            +
                        rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                    self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
         | 
| 344 | 
            +
                    self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
         | 
| 345 | 
            +
                    self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
         | 
| 346 | 
            +
                    self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
         | 
| 347 | 
            +
                    self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
         | 
| 348 | 
            +
                    self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                    self.active_sh_degree = self.max_sh_degree
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                def replace_tensor_to_optimizer(self, tensor, name):
         | 
| 353 | 
            +
                    optimizable_tensors = {}
         | 
| 354 | 
            +
                    for group in self.optimizer.param_groups:
         | 
| 355 | 
            +
                        if group["name"] == name:
         | 
| 356 | 
            +
                            # breakpoint()
         | 
| 357 | 
            +
                            stored_state = self.optimizer.state.get(group['params'][0], None)
         | 
| 358 | 
            +
                            stored_state["exp_avg"] = torch.zeros_like(tensor)
         | 
| 359 | 
            +
                            stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                            del self.optimizer.state[group['params'][0]]
         | 
| 362 | 
            +
                            group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
         | 
| 363 | 
            +
                            self.optimizer.state[group['params'][0]] = stored_state
         | 
| 364 | 
            +
             | 
| 365 | 
            +
                            optimizable_tensors[group["name"]] = group["params"][0]
         | 
| 366 | 
            +
                    return optimizable_tensors
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                def _prune_optimizer(self, mask):
         | 
| 369 | 
            +
                    optimizable_tensors = {}
         | 
| 370 | 
            +
                    for group in self.optimizer.param_groups:
         | 
| 371 | 
            +
                        stored_state = self.optimizer.state.get(group['params'][0], None)
         | 
| 372 | 
            +
                        if stored_state is not None:
         | 
| 373 | 
            +
                            stored_state["exp_avg"] = stored_state["exp_avg"][mask]
         | 
| 374 | 
            +
                            stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                            del self.optimizer.state[group['params'][0]]
         | 
| 377 | 
            +
                            group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
         | 
| 378 | 
            +
                            self.optimizer.state[group['params'][0]] = stored_state
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                            optimizable_tensors[group["name"]] = group["params"][0]
         | 
| 381 | 
            +
                        else:
         | 
| 382 | 
            +
                            group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
         | 
| 383 | 
            +
                            optimizable_tensors[group["name"]] = group["params"][0]
         | 
| 384 | 
            +
                    return optimizable_tensors
         | 
| 385 | 
            +
             | 
| 386 | 
            +
                def prune_points(self, mask):
         | 
| 387 | 
            +
                    valid_points_mask = ~mask
         | 
| 388 | 
            +
                    optimizable_tensors = self._prune_optimizer(valid_points_mask)
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                    self._xyz = optimizable_tensors["xyz"]
         | 
| 391 | 
            +
                    self._features_dc = optimizable_tensors["f_dc"]
         | 
| 392 | 
            +
                    self._features_rest = optimizable_tensors["f_rest"]
         | 
| 393 | 
            +
                    self._opacity = optimizable_tensors["opacity"]
         | 
| 394 | 
            +
                    self._scaling = optimizable_tensors["scaling"]
         | 
| 395 | 
            +
                    self._rotation = optimizable_tensors["rotation"]
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                    self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                    self.denom = self.denom[valid_points_mask]
         | 
| 400 | 
            +
                    self.max_radii2D = self.max_radii2D[valid_points_mask]
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                def cat_tensors_to_optimizer(self, tensors_dict):
         | 
| 403 | 
            +
                    optimizable_tensors = {}
         | 
| 404 | 
            +
                    for group in self.optimizer.param_groups:
         | 
| 405 | 
            +
                        assert len(group["params"]) == 1
         | 
| 406 | 
            +
                        extension_tensor = tensors_dict[group["name"]]
         | 
| 407 | 
            +
                        stored_state = self.optimizer.state.get(group['params'][0], None)
         | 
| 408 | 
            +
                        if stored_state is not None:
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                            stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
         | 
| 411 | 
            +
                            stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                            del self.optimizer.state[group['params'][0]]
         | 
| 414 | 
            +
                            group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
         | 
| 415 | 
            +
                            self.optimizer.state[group['params'][0]] = stored_state
         | 
| 416 | 
            +
             | 
| 417 | 
            +
                            optimizable_tensors[group["name"]] = group["params"][0]
         | 
| 418 | 
            +
                        else:
         | 
| 419 | 
            +
                            group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
         | 
| 420 | 
            +
                            optimizable_tensors[group["name"]] = group["params"][0]
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                    return optimizable_tensors
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
         | 
| 425 | 
            +
                    d = {"xyz": new_xyz,
         | 
| 426 | 
            +
                    "f_dc": new_features_dc,
         | 
| 427 | 
            +
                    "f_rest": new_features_rest,
         | 
| 428 | 
            +
                    "opacity": new_opacities,
         | 
| 429 | 
            +
                    "scaling" : new_scaling,
         | 
| 430 | 
            +
                    "rotation" : new_rotation}
         | 
| 431 | 
            +
             | 
| 432 | 
            +
                    optimizable_tensors = self.cat_tensors_to_optimizer(d)
         | 
| 433 | 
            +
                    self._xyz = optimizable_tensors["xyz"]
         | 
| 434 | 
            +
                    self._features_dc = optimizable_tensors["f_dc"]
         | 
| 435 | 
            +
                    self._features_rest = optimizable_tensors["f_rest"]
         | 
| 436 | 
            +
                    self._opacity = optimizable_tensors["opacity"]
         | 
| 437 | 
            +
                    self._scaling = optimizable_tensors["scaling"]
         | 
| 438 | 
            +
                    self._rotation = optimizable_tensors["rotation"]
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                    self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
         | 
| 441 | 
            +
                    self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
         | 
| 442 | 
            +
                    self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
         | 
| 445 | 
            +
                    n_init_points = self.get_xyz.shape[0]
         | 
| 446 | 
            +
                    # Extract points that satisfy the gradient condition
         | 
| 447 | 
            +
                    padded_grad = torch.zeros((n_init_points), device="cuda")
         | 
| 448 | 
            +
                    padded_grad[:grads.shape[0]] = grads.squeeze()
         | 
| 449 | 
            +
                    selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
         | 
| 450 | 
            +
                    selected_pts_mask = torch.logical_and(selected_pts_mask,
         | 
| 451 | 
            +
                                                          torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                    stds = self.get_scaling[selected_pts_mask].repeat(N,1)
         | 
| 454 | 
            +
                    means =torch.zeros((stds.size(0), 3),device="cuda")
         | 
| 455 | 
            +
                    samples = torch.normal(mean=means, std=stds)
         | 
| 456 | 
            +
                    rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
         | 
| 457 | 
            +
                    new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
         | 
| 458 | 
            +
                    new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
         | 
| 459 | 
            +
                    new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
         | 
| 460 | 
            +
                    new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
         | 
| 461 | 
            +
                    new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
         | 
| 462 | 
            +
                    new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                    self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                    prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
         | 
| 467 | 
            +
                    self.prune_points(prune_filter)
         | 
| 468 | 
            +
             | 
| 469 | 
            +
                def densify_and_clone(self, grads, grad_threshold, scene_extent):
         | 
| 470 | 
            +
                    # Extract points that satisfy the gradient condition
         | 
| 471 | 
            +
                    selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
         | 
| 472 | 
            +
                    selected_pts_mask = torch.logical_and(selected_pts_mask,
         | 
| 473 | 
            +
                                                          torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
         | 
| 474 | 
            +
                    
         | 
| 475 | 
            +
                    new_xyz = self._xyz[selected_pts_mask]
         | 
| 476 | 
            +
                    new_features_dc = self._features_dc[selected_pts_mask]
         | 
| 477 | 
            +
                    new_features_rest = self._features_rest[selected_pts_mask]
         | 
| 478 | 
            +
                    new_opacities = self._opacity[selected_pts_mask]
         | 
| 479 | 
            +
                    new_scaling = self._scaling[selected_pts_mask]
         | 
| 480 | 
            +
                    new_rotation = self._rotation[selected_pts_mask]
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
         | 
| 485 | 
            +
                    grads = self.xyz_gradient_accum / self.denom
         | 
| 486 | 
            +
                    grads[grads.isnan()] = 0.0
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                    # self.densify_and_clone(grads, max_grad, extent)
         | 
| 489 | 
            +
                    # self.densify_and_split(grads, max_grad, extent)
         | 
| 490 | 
            +
             | 
| 491 | 
            +
                    prune_mask = (self.get_opacity < min_opacity).squeeze()
         | 
| 492 | 
            +
                    if max_screen_size:
         | 
| 493 | 
            +
                        big_points_vs = self.max_radii2D > max_screen_size
         | 
| 494 | 
            +
                        big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
         | 
| 495 | 
            +
                        prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
         | 
| 496 | 
            +
                    self.prune_points(prune_mask)
         | 
| 497 | 
            +
             | 
| 498 | 
            +
                    torch.cuda.empty_cache()
         | 
| 499 | 
            +
             | 
| 500 | 
            +
                def add_densification_stats(self, viewspace_point_tensor, update_filter):
         | 
| 501 | 
            +
                    self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
         | 
| 502 | 
            +
                    self.denom[update_filter] += 1
         | 
    	
        submodules/diff-gaussian-rasterization/.gitignore
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            build/
         | 
| 2 | 
            +
            diff_gaussian_rasterization.egg-info/
         | 
| 3 | 
            +
            dist/
         | 
    	
        submodules/diff-gaussian-rasterization/.gitmodules
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            [submodule "third_party/glm"]
         | 
| 2 | 
            +
            	path = third_party/glm
         | 
| 3 | 
            +
            	url = https://github.com/g-truc/glm.git
         | 
    	
        submodules/diff-gaussian-rasterization/CMakeLists.txt
    ADDED
    
    | @@ -0,0 +1,36 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            cmake_minimum_required(VERSION 3.20)
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            project(DiffRast LANGUAGES CUDA CXX)
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            set(CMAKE_CXX_STANDARD 17)
         | 
| 17 | 
            +
            set(CMAKE_CXX_EXTENSIONS OFF)
         | 
| 18 | 
            +
            set(CMAKE_CUDA_STANDARD 17)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            add_library(CudaRasterizer
         | 
| 23 | 
            +
            	cuda_rasterizer/backward.h
         | 
| 24 | 
            +
            	cuda_rasterizer/backward.cu
         | 
| 25 | 
            +
            	cuda_rasterizer/forward.h
         | 
| 26 | 
            +
            	cuda_rasterizer/forward.cu
         | 
| 27 | 
            +
            	cuda_rasterizer/auxiliary.h
         | 
| 28 | 
            +
            	cuda_rasterizer/rasterizer_impl.cu
         | 
| 29 | 
            +
            	cuda_rasterizer/rasterizer_impl.h
         | 
| 30 | 
            +
            	cuda_rasterizer/rasterizer.h
         | 
| 31 | 
            +
            )
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            set_target_properties(CudaRasterizer PROPERTIES CUDA_ARCHITECTURES "70;75;86")
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            target_include_directories(CudaRasterizer PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rasterizer)
         | 
| 36 | 
            +
            target_include_directories(CudaRasterizer PRIVATE third_party/glm ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
         | 
    	
        submodules/diff-gaussian-rasterization/LICENSE.md
    ADDED
    
    | @@ -0,0 +1,83 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            Gaussian-Splatting License  
         | 
| 2 | 
            +
            ===========================  
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.  
         | 
| 5 | 
            +
            The *Software* is in the process of being registered with the Agence pour la Protection des  
         | 
| 6 | 
            +
            Programmes (APP).  
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            The *Software* is still being developed by the *Licensor*.  
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            *Licensor*'s goal is to allow the research community to use, test and evaluate  
         | 
| 11 | 
            +
            the *Software*.  
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            ## 1.  Definitions  
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            *Licensee* means any person or entity that uses the *Software* and distributes  
         | 
| 16 | 
            +
            its *Work*.  
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            *Licensor* means the owners of the *Software*, i.e Inria and MPII  
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            *Software* means the original work of authorship made available under this  
         | 
| 21 | 
            +
            License ie gaussian-splatting.  
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            *Work* means the *Software* and any additions to or derivative works of the  
         | 
| 24 | 
            +
            *Software* that are made available under this License.  
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            ## 2.  Purpose  
         | 
| 28 | 
            +
            This license is intended to define the rights granted to the *Licensee* by  
         | 
| 29 | 
            +
            Licensors under the *Software*.  
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            ## 3.  Rights granted  
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            For the above reasons Licensors have decided to distribute the *Software*.  
         | 
| 34 | 
            +
            Licensors grant non-exclusive rights to use the *Software* for research purposes  
         | 
| 35 | 
            +
            to research users (both academic and industrial), free of charge, without right  
         | 
| 36 | 
            +
            to sublicense.. The *Software* may be used "non-commercially", i.e., for research  
         | 
| 37 | 
            +
            and/or evaluation purposes only.  
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            Subject to the terms and conditions of this License, you are granted a  
         | 
| 40 | 
            +
            non-exclusive, royalty-free, license to reproduce, prepare derivative works of,  
         | 
| 41 | 
            +
            publicly display, publicly perform and distribute its *Work* and any resulting  
         | 
| 42 | 
            +
            derivative works in any form.  
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            ## 4.  Limitations  
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do  
         | 
| 47 | 
            +
            so under this License, (b) you include a complete copy of this License with  
         | 
| 48 | 
            +
            your distribution, and (c) you retain without modification any copyright,  
         | 
| 49 | 
            +
            patent, trademark, or attribution notices that are present in the *Work*.  
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            **4.2 Derivative Works.** You may specify that additional or different terms apply  
         | 
| 52 | 
            +
            to the use, reproduction, and distribution of your derivative works of the *Work*  
         | 
| 53 | 
            +
            ("Your Terms") only if (a) Your Terms provide that the use limitation in  
         | 
| 54 | 
            +
            Section 2 applies to your derivative works, and (b) you identify the specific  
         | 
| 55 | 
            +
            derivative works that are subject to Your Terms. Notwithstanding Your Terms,  
         | 
| 56 | 
            +
            this License (including the redistribution requirements in Section 3.1) will  
         | 
| 57 | 
            +
            continue to apply to the *Work* itself.  
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            **4.3** Any other use without of prior consent of Licensors is prohibited. Research  
         | 
| 60 | 
            +
            users explicitly acknowledge having received from Licensors all information  
         | 
| 61 | 
            +
            allowing to appreciate the adequacy between of the *Software* and their needs and  
         | 
| 62 | 
            +
            to undertake all necessary precautions for its execution and use.  
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            **4.4** The *Software* is provided both as a compiled library file and as source  
         | 
| 65 | 
            +
            code. In case of using the *Software* for a publication or other results obtained  
         | 
| 66 | 
            +
            through the use of the *Software*, users are strongly encouraged to cite the  
         | 
| 67 | 
            +
            corresponding publications as explained in the documentation of the *Software*.  
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            ## 5.  Disclaimer  
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES  
         | 
| 72 | 
            +
            WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY  
         | 
| 73 | 
            +
            UNAUTHORIZED USE: [email protected] . ANY SUCH ACTION WILL  
         | 
| 74 | 
            +
            CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES  
         | 
| 75 | 
            +
            OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL  
         | 
| 76 | 
            +
            USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR  
         | 
| 77 | 
            +
            ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE  
         | 
| 78 | 
            +
            AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR  
         | 
| 79 | 
            +
            CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE  
         | 
| 80 | 
            +
            GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)  
         | 
| 81 | 
            +
            HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT  
         | 
| 82 | 
            +
            LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR  
         | 
| 83 | 
            +
            IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.  
         | 
    	
        submodules/diff-gaussian-rasterization/README.md
    ADDED
    
    | @@ -0,0 +1,19 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Differential Gaussian Rasterization
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Used as the rasterization engine for the paper "3D Gaussian Splatting for Real-Time Rendering of Radiance Fields". If you can make use of it in your own research, please be so kind to cite us.
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            <section class="section" id="BibTeX">
         | 
| 6 | 
            +
              <div class="container is-max-desktop content">
         | 
| 7 | 
            +
                <h2 class="title">BibTeX</h2>
         | 
| 8 | 
            +
                <pre><code>@Article{kerbl3Dgaussians,
         | 
| 9 | 
            +
                  author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
         | 
| 10 | 
            +
                  title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
         | 
| 11 | 
            +
                  journal      = {ACM Transactions on Graphics},
         | 
| 12 | 
            +
                  number       = {4},
         | 
| 13 | 
            +
                  volume       = {42},
         | 
| 14 | 
            +
                  month        = {July},
         | 
| 15 | 
            +
                  year         = {2023},
         | 
| 16 | 
            +
                  url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
         | 
| 17 | 
            +
            }</code></pre>
         | 
| 18 | 
            +
              </div>
         | 
| 19 | 
            +
            </section>
         | 
    	
        submodules/diff-gaussian-rasterization/cuda_rasterizer/auxiliary.h
    ADDED
    
    | @@ -0,0 +1,175 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #ifndef CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
         | 
| 13 | 
            +
            #define CUDA_RASTERIZER_AUXILIARY_H_INCLUDED
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            #include "config.h"
         | 
| 16 | 
            +
            #include "stdio.h"
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            #define BLOCK_SIZE (BLOCK_X * BLOCK_Y)
         | 
| 19 | 
            +
            #define NUM_WARPS (BLOCK_SIZE/32)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            // Spherical harmonics coefficients
         | 
| 22 | 
            +
            __device__ const float SH_C0 = 0.28209479177387814f;
         | 
| 23 | 
            +
            __device__ const float SH_C1 = 0.4886025119029199f;
         | 
| 24 | 
            +
            __device__ const float SH_C2[] = {
         | 
| 25 | 
            +
            	1.0925484305920792f,
         | 
| 26 | 
            +
            	-1.0925484305920792f,
         | 
| 27 | 
            +
            	0.31539156525252005f,
         | 
| 28 | 
            +
            	-1.0925484305920792f,
         | 
| 29 | 
            +
            	0.5462742152960396f
         | 
| 30 | 
            +
            };
         | 
| 31 | 
            +
            __device__ const float SH_C3[] = {
         | 
| 32 | 
            +
            	-0.5900435899266435f,
         | 
| 33 | 
            +
            	2.890611442640554f,
         | 
| 34 | 
            +
            	-0.4570457994644658f,
         | 
| 35 | 
            +
            	0.3731763325901154f,
         | 
| 36 | 
            +
            	-0.4570457994644658f,
         | 
| 37 | 
            +
            	1.445305721320277f,
         | 
| 38 | 
            +
            	-0.5900435899266435f
         | 
| 39 | 
            +
            };
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            __forceinline__ __device__ float ndc2Pix(float v, int S)
         | 
| 42 | 
            +
            {
         | 
| 43 | 
            +
            	return ((v + 1.0) * S - 1.0) * 0.5;
         | 
| 44 | 
            +
            }
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            __forceinline__ __device__ void getRect(const float2 p, int max_radius, uint2& rect_min, uint2& rect_max, dim3 grid)
         | 
| 47 | 
            +
            {
         | 
| 48 | 
            +
            	rect_min = {
         | 
| 49 | 
            +
            		min(grid.x, max((int)0, (int)((p.x - max_radius) / BLOCK_X))),
         | 
| 50 | 
            +
            		min(grid.y, max((int)0, (int)((p.y - max_radius) / BLOCK_Y)))
         | 
| 51 | 
            +
            	};
         | 
| 52 | 
            +
            	rect_max = {
         | 
| 53 | 
            +
            		min(grid.x, max((int)0, (int)((p.x + max_radius + BLOCK_X - 1) / BLOCK_X))),
         | 
| 54 | 
            +
            		min(grid.y, max((int)0, (int)((p.y + max_radius + BLOCK_Y - 1) / BLOCK_Y)))
         | 
| 55 | 
            +
            	};
         | 
| 56 | 
            +
            }
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            __forceinline__ __device__ float3 transformPoint4x3(const float3& p, const float* matrix)
         | 
| 59 | 
            +
            {
         | 
| 60 | 
            +
            	float3 transformed = {
         | 
| 61 | 
            +
            		matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
         | 
| 62 | 
            +
            		matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
         | 
| 63 | 
            +
            		matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
         | 
| 64 | 
            +
            	};
         | 
| 65 | 
            +
            	return transformed;
         | 
| 66 | 
            +
            }
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            __forceinline__ __device__ float4 transformPoint4x4(const float3& p, const float* matrix)
         | 
| 69 | 
            +
            {
         | 
| 70 | 
            +
            	float4 transformed = {
         | 
| 71 | 
            +
            		matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z + matrix[12],
         | 
| 72 | 
            +
            		matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z + matrix[13],
         | 
| 73 | 
            +
            		matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z + matrix[14],
         | 
| 74 | 
            +
            		matrix[3] * p.x + matrix[7] * p.y + matrix[11] * p.z + matrix[15]
         | 
| 75 | 
            +
            	};
         | 
| 76 | 
            +
            	return transformed;
         | 
| 77 | 
            +
            }
         | 
| 78 | 
            +
             | 
| 79 | 
            +
            __forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix)
         | 
| 80 | 
            +
            {
         | 
| 81 | 
            +
            	float3 transformed = {
         | 
| 82 | 
            +
            		matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z,
         | 
| 83 | 
            +
            		matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z,
         | 
| 84 | 
            +
            		matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z,
         | 
| 85 | 
            +
            	};
         | 
| 86 | 
            +
            	return transformed;
         | 
| 87 | 
            +
            }
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            __forceinline__ __device__ float3 transformVec4x3Transpose(const float3& p, const float* matrix)
         | 
| 90 | 
            +
            {
         | 
| 91 | 
            +
            	float3 transformed = {
         | 
| 92 | 
            +
            		matrix[0] * p.x + matrix[1] * p.y + matrix[2] * p.z,
         | 
| 93 | 
            +
            		matrix[4] * p.x + matrix[5] * p.y + matrix[6] * p.z,
         | 
| 94 | 
            +
            		matrix[8] * p.x + matrix[9] * p.y + matrix[10] * p.z,
         | 
| 95 | 
            +
            	};
         | 
| 96 | 
            +
            	return transformed;
         | 
| 97 | 
            +
            }
         | 
| 98 | 
            +
             | 
| 99 | 
            +
            __forceinline__ __device__ float dnormvdz(float3 v, float3 dv)
         | 
| 100 | 
            +
            {
         | 
| 101 | 
            +
            	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
         | 
| 102 | 
            +
            	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
         | 
| 103 | 
            +
            	float dnormvdz = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
         | 
| 104 | 
            +
            	return dnormvdz;
         | 
| 105 | 
            +
            }
         | 
| 106 | 
            +
             | 
| 107 | 
            +
            __forceinline__ __device__ float3 dnormvdv(float3 v, float3 dv)
         | 
| 108 | 
            +
            {
         | 
| 109 | 
            +
            	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z;
         | 
| 110 | 
            +
            	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
         | 
| 111 | 
            +
             | 
| 112 | 
            +
            	float3 dnormvdv;
         | 
| 113 | 
            +
            	dnormvdv.x = ((+sum2 - v.x * v.x) * dv.x - v.y * v.x * dv.y - v.z * v.x * dv.z) * invsum32;
         | 
| 114 | 
            +
            	dnormvdv.y = (-v.x * v.y * dv.x + (sum2 - v.y * v.y) * dv.y - v.z * v.y * dv.z) * invsum32;
         | 
| 115 | 
            +
            	dnormvdv.z = (-v.x * v.z * dv.x - v.y * v.z * dv.y + (sum2 - v.z * v.z) * dv.z) * invsum32;
         | 
| 116 | 
            +
            	return dnormvdv;
         | 
| 117 | 
            +
            }
         | 
| 118 | 
            +
             | 
| 119 | 
            +
            __forceinline__ __device__ float4 dnormvdv(float4 v, float4 dv)
         | 
| 120 | 
            +
            {
         | 
| 121 | 
            +
            	float sum2 = v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
         | 
| 122 | 
            +
            	float invsum32 = 1.0f / sqrt(sum2 * sum2 * sum2);
         | 
| 123 | 
            +
             | 
| 124 | 
            +
            	float4 vdv = { v.x * dv.x, v.y * dv.y, v.z * dv.z, v.w * dv.w };
         | 
| 125 | 
            +
            	float vdv_sum = vdv.x + vdv.y + vdv.z + vdv.w;
         | 
| 126 | 
            +
            	float4 dnormvdv;
         | 
| 127 | 
            +
            	dnormvdv.x = ((sum2 - v.x * v.x) * dv.x - v.x * (vdv_sum - vdv.x)) * invsum32;
         | 
| 128 | 
            +
            	dnormvdv.y = ((sum2 - v.y * v.y) * dv.y - v.y * (vdv_sum - vdv.y)) * invsum32;
         | 
| 129 | 
            +
            	dnormvdv.z = ((sum2 - v.z * v.z) * dv.z - v.z * (vdv_sum - vdv.z)) * invsum32;
         | 
| 130 | 
            +
            	dnormvdv.w = ((sum2 - v.w * v.w) * dv.w - v.w * (vdv_sum - vdv.w)) * invsum32;
         | 
| 131 | 
            +
            	return dnormvdv;
         | 
| 132 | 
            +
            }
         | 
| 133 | 
            +
             | 
| 134 | 
            +
            __forceinline__ __device__ float sigmoid(float x)
         | 
| 135 | 
            +
            {
         | 
| 136 | 
            +
            	return 1.0f / (1.0f + expf(-x));
         | 
| 137 | 
            +
            }
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            __forceinline__ __device__ bool in_frustum(int idx,
         | 
| 140 | 
            +
            	const float* orig_points,
         | 
| 141 | 
            +
            	const float* viewmatrix,
         | 
| 142 | 
            +
            	const float* projmatrix,
         | 
| 143 | 
            +
            	bool prefiltered,
         | 
| 144 | 
            +
            	float3& p_view)
         | 
| 145 | 
            +
            {
         | 
| 146 | 
            +
            	float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
         | 
| 147 | 
            +
             | 
| 148 | 
            +
            	// Bring points to screen space
         | 
| 149 | 
            +
            	float4 p_hom = transformPoint4x4(p_orig, projmatrix);
         | 
| 150 | 
            +
            	float p_w = 1.0f / (p_hom.w + 0.0000001f);
         | 
| 151 | 
            +
            	float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
         | 
| 152 | 
            +
            	p_view = transformPoint4x3(p_orig, viewmatrix);
         | 
| 153 | 
            +
             | 
| 154 | 
            +
            	if (p_view.z <= 0.01f)// || ((p_proj.x < -1.3 || p_proj.x > 1.3 || p_proj.y < -1.3 || p_proj.y > 1.3)))
         | 
| 155 | 
            +
            	{
         | 
| 156 | 
            +
            		if (prefiltered)
         | 
| 157 | 
            +
            		{
         | 
| 158 | 
            +
            			printf("Point is filtered although prefiltered is set. This shouldn't happen!");
         | 
| 159 | 
            +
            			__trap();
         | 
| 160 | 
            +
            		}
         | 
| 161 | 
            +
            		return false;
         | 
| 162 | 
            +
            	}
         | 
| 163 | 
            +
            	return true;
         | 
| 164 | 
            +
            }
         | 
| 165 | 
            +
             | 
| 166 | 
            +
            #define CHECK_CUDA(A, debug) \
         | 
| 167 | 
            +
            A; if(debug) { \
         | 
| 168 | 
            +
            auto ret = cudaDeviceSynchronize(); \
         | 
| 169 | 
            +
            if (ret != cudaSuccess) { \
         | 
| 170 | 
            +
            std::cerr << "\n[CUDA ERROR] in " << __FILE__ << "\nLine " << __LINE__ << ": " << cudaGetErrorString(ret); \
         | 
| 171 | 
            +
            throw std::runtime_error(cudaGetErrorString(ret)); \
         | 
| 172 | 
            +
            } \
         | 
| 173 | 
            +
            }
         | 
| 174 | 
            +
             | 
| 175 | 
            +
            #endif
         | 
    	
        submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.cu
    ADDED
    
    | @@ -0,0 +1,657 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #include "backward.h"
         | 
| 13 | 
            +
            #include "auxiliary.h"
         | 
| 14 | 
            +
            #include <cooperative_groups.h>
         | 
| 15 | 
            +
            #include <cooperative_groups/reduce.h>
         | 
| 16 | 
            +
            namespace cg = cooperative_groups;
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            // Backward pass for conversion of spherical harmonics to RGB for
         | 
| 19 | 
            +
            // each Gaussian.
         | 
| 20 | 
            +
            __device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs)
         | 
| 21 | 
            +
            {
         | 
| 22 | 
            +
            	// Compute intermediate values, as it is done during forward
         | 
| 23 | 
            +
            	glm::vec3 pos = means[idx];
         | 
| 24 | 
            +
            	glm::vec3 dir_orig = pos - campos;
         | 
| 25 | 
            +
            	glm::vec3 dir = dir_orig / glm::length(dir_orig);
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            	glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            	// Use PyTorch rule for clamping: if clamping was applied,
         | 
| 30 | 
            +
            	// gradient becomes 0.
         | 
| 31 | 
            +
            	glm::vec3 dL_dRGB = dL_dcolor[idx];
         | 
| 32 | 
            +
            	dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1;
         | 
| 33 | 
            +
            	dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1;
         | 
| 34 | 
            +
            	dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1;
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            	glm::vec3 dRGBdx(0, 0, 0);
         | 
| 37 | 
            +
            	glm::vec3 dRGBdy(0, 0, 0);
         | 
| 38 | 
            +
            	glm::vec3 dRGBdz(0, 0, 0);
         | 
| 39 | 
            +
            	float x = dir.x;
         | 
| 40 | 
            +
            	float y = dir.y;
         | 
| 41 | 
            +
            	float z = dir.z;
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            	// Target location for this Gaussian to write SH gradients to
         | 
| 44 | 
            +
            	glm::vec3* dL_dsh = dL_dshs + idx * max_coeffs;
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            	// No tricks here, just high school-level calculus.
         | 
| 47 | 
            +
            	float dRGBdsh0 = SH_C0;
         | 
| 48 | 
            +
            	dL_dsh[0] = dRGBdsh0 * dL_dRGB;
         | 
| 49 | 
            +
            	if (deg > 0)
         | 
| 50 | 
            +
            	{
         | 
| 51 | 
            +
            		float dRGBdsh1 = -SH_C1 * y;
         | 
| 52 | 
            +
            		float dRGBdsh2 = SH_C1 * z;
         | 
| 53 | 
            +
            		float dRGBdsh3 = -SH_C1 * x;
         | 
| 54 | 
            +
            		dL_dsh[1] = dRGBdsh1 * dL_dRGB;
         | 
| 55 | 
            +
            		dL_dsh[2] = dRGBdsh2 * dL_dRGB;
         | 
| 56 | 
            +
            		dL_dsh[3] = dRGBdsh3 * dL_dRGB;
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            		dRGBdx = -SH_C1 * sh[3];
         | 
| 59 | 
            +
            		dRGBdy = -SH_C1 * sh[1];
         | 
| 60 | 
            +
            		dRGBdz = SH_C1 * sh[2];
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            		if (deg > 1)
         | 
| 63 | 
            +
            		{
         | 
| 64 | 
            +
            			float xx = x * x, yy = y * y, zz = z * z;
         | 
| 65 | 
            +
            			float xy = x * y, yz = y * z, xz = x * z;
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            			float dRGBdsh4 = SH_C2[0] * xy;
         | 
| 68 | 
            +
            			float dRGBdsh5 = SH_C2[1] * yz;
         | 
| 69 | 
            +
            			float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy);
         | 
| 70 | 
            +
            			float dRGBdsh7 = SH_C2[3] * xz;
         | 
| 71 | 
            +
            			float dRGBdsh8 = SH_C2[4] * (xx - yy);
         | 
| 72 | 
            +
            			dL_dsh[4] = dRGBdsh4 * dL_dRGB;
         | 
| 73 | 
            +
            			dL_dsh[5] = dRGBdsh5 * dL_dRGB;
         | 
| 74 | 
            +
            			dL_dsh[6] = dRGBdsh6 * dL_dRGB;
         | 
| 75 | 
            +
            			dL_dsh[7] = dRGBdsh7 * dL_dRGB;
         | 
| 76 | 
            +
            			dL_dsh[8] = dRGBdsh8 * dL_dRGB;
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            			dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
         | 
| 79 | 
            +
            			dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8];
         | 
| 80 | 
            +
            			dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7];
         | 
| 81 | 
            +
             | 
| 82 | 
            +
            			if (deg > 2)
         | 
| 83 | 
            +
            			{
         | 
| 84 | 
            +
            				float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy);
         | 
| 85 | 
            +
            				float dRGBdsh10 = SH_C3[1] * xy * z;
         | 
| 86 | 
            +
            				float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy);
         | 
| 87 | 
            +
            				float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy);
         | 
| 88 | 
            +
            				float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy);
         | 
| 89 | 
            +
            				float dRGBdsh14 = SH_C3[5] * z * (xx - yy);
         | 
| 90 | 
            +
            				float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy);
         | 
| 91 | 
            +
            				dL_dsh[9] = dRGBdsh9 * dL_dRGB;
         | 
| 92 | 
            +
            				dL_dsh[10] = dRGBdsh10 * dL_dRGB;
         | 
| 93 | 
            +
            				dL_dsh[11] = dRGBdsh11 * dL_dRGB;
         | 
| 94 | 
            +
            				dL_dsh[12] = dRGBdsh12 * dL_dRGB;
         | 
| 95 | 
            +
            				dL_dsh[13] = dRGBdsh13 * dL_dRGB;
         | 
| 96 | 
            +
            				dL_dsh[14] = dRGBdsh14 * dL_dRGB;
         | 
| 97 | 
            +
            				dL_dsh[15] = dRGBdsh15 * dL_dRGB;
         | 
| 98 | 
            +
             | 
| 99 | 
            +
            				dRGBdx += (
         | 
| 100 | 
            +
            					SH_C3[0] * sh[9] * 3.f * 2.f * xy +
         | 
| 101 | 
            +
            					SH_C3[1] * sh[10] * yz +
         | 
| 102 | 
            +
            					SH_C3[2] * sh[11] * -2.f * xy +
         | 
| 103 | 
            +
            					SH_C3[3] * sh[12] * -3.f * 2.f * xz +
         | 
| 104 | 
            +
            					SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) +
         | 
| 105 | 
            +
            					SH_C3[5] * sh[14] * 2.f * xz +
         | 
| 106 | 
            +
            					SH_C3[6] * sh[15] * 3.f * (xx - yy));
         | 
| 107 | 
            +
             | 
| 108 | 
            +
            				dRGBdy += (
         | 
| 109 | 
            +
            					SH_C3[0] * sh[9] * 3.f * (xx - yy) +
         | 
| 110 | 
            +
            					SH_C3[1] * sh[10] * xz +
         | 
| 111 | 
            +
            					SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) +
         | 
| 112 | 
            +
            					SH_C3[3] * sh[12] * -3.f * 2.f * yz +
         | 
| 113 | 
            +
            					SH_C3[4] * sh[13] * -2.f * xy +
         | 
| 114 | 
            +
            					SH_C3[5] * sh[14] * -2.f * yz +
         | 
| 115 | 
            +
            					SH_C3[6] * sh[15] * -3.f * 2.f * xy);
         | 
| 116 | 
            +
             | 
| 117 | 
            +
            				dRGBdz += (
         | 
| 118 | 
            +
            					SH_C3[1] * sh[10] * xy +
         | 
| 119 | 
            +
            					SH_C3[2] * sh[11] * 4.f * 2.f * yz +
         | 
| 120 | 
            +
            					SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) +
         | 
| 121 | 
            +
            					SH_C3[4] * sh[13] * 4.f * 2.f * xz +
         | 
| 122 | 
            +
            					SH_C3[5] * sh[14] * (xx - yy));
         | 
| 123 | 
            +
            			}
         | 
| 124 | 
            +
            		}
         | 
| 125 | 
            +
            	}
         | 
| 126 | 
            +
             | 
| 127 | 
            +
            	// The view direction is an input to the computation. View direction
         | 
| 128 | 
            +
            	// is influenced by the Gaussian's mean, so SHs gradients
         | 
| 129 | 
            +
            	// must propagate back into 3D position.
         | 
| 130 | 
            +
            	glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB), glm::dot(dRGBdz, dL_dRGB));
         | 
| 131 | 
            +
             | 
| 132 | 
            +
            	// Account for normalization of direction
         | 
| 133 | 
            +
            	float3 dL_dmean = dnormvdv(float3{ dir_orig.x, dir_orig.y, dir_orig.z }, float3{ dL_ddir.x, dL_ddir.y, dL_ddir.z });
         | 
| 134 | 
            +
             | 
| 135 | 
            +
            	// Gradients of loss w.r.t. Gaussian means, but only the portion 
         | 
| 136 | 
            +
            	// that is caused because the mean affects the view-dependent color.
         | 
| 137 | 
            +
            	// Additional mean gradient is accumulated in below methods.
         | 
| 138 | 
            +
            	dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z);
         | 
| 139 | 
            +
            }
         | 
| 140 | 
            +
             | 
| 141 | 
            +
            // Backward version of INVERSE 2D covariance matrix computation
         | 
| 142 | 
            +
            // (due to length launched as separate kernel before other 
         | 
| 143 | 
            +
            // backward steps contained in preprocess)
         | 
| 144 | 
            +
            __global__ void computeCov2DCUDA(int P,
         | 
| 145 | 
            +
            	const float3* means,
         | 
| 146 | 
            +
            	const int* radii,
         | 
| 147 | 
            +
            	const float* cov3Ds,
         | 
| 148 | 
            +
            	const float h_x, float h_y,
         | 
| 149 | 
            +
            	const float tan_fovx, float tan_fovy,
         | 
| 150 | 
            +
            	const float* view_matrix,
         | 
| 151 | 
            +
            	const float* dL_dconics,
         | 
| 152 | 
            +
            	float3* dL_dmeans,
         | 
| 153 | 
            +
            	float* dL_dcov)
         | 
| 154 | 
            +
            {
         | 
| 155 | 
            +
            	auto idx = cg::this_grid().thread_rank();
         | 
| 156 | 
            +
            	if (idx >= P || !(radii[idx] > 0))
         | 
| 157 | 
            +
            		return;
         | 
| 158 | 
            +
             | 
| 159 | 
            +
            	// Reading location of 3D covariance for this Gaussian
         | 
| 160 | 
            +
            	const float* cov3D = cov3Ds + 6 * idx;
         | 
| 161 | 
            +
             | 
| 162 | 
            +
            	// Fetch gradients, recompute 2D covariance and relevant 
         | 
| 163 | 
            +
            	// intermediate forward results needed in the backward.
         | 
| 164 | 
            +
            	float3 mean = means[idx];
         | 
| 165 | 
            +
            	float3 dL_dconic = { dL_dconics[4 * idx], dL_dconics[4 * idx + 1], dL_dconics[4 * idx + 3] };
         | 
| 166 | 
            +
            	float3 t = transformPoint4x3(mean, view_matrix);
         | 
| 167 | 
            +
            	
         | 
| 168 | 
            +
            	const float limx = 1.3f * tan_fovx;
         | 
| 169 | 
            +
            	const float limy = 1.3f * tan_fovy;
         | 
| 170 | 
            +
            	const float txtz = t.x / t.z;
         | 
| 171 | 
            +
            	const float tytz = t.y / t.z;
         | 
| 172 | 
            +
            	t.x = min(limx, max(-limx, txtz)) * t.z;
         | 
| 173 | 
            +
            	t.y = min(limy, max(-limy, tytz)) * t.z;
         | 
| 174 | 
            +
            	
         | 
| 175 | 
            +
            	const float x_grad_mul = txtz < -limx || txtz > limx ? 0 : 1;
         | 
| 176 | 
            +
            	const float y_grad_mul = tytz < -limy || tytz > limy ? 0 : 1;
         | 
| 177 | 
            +
             | 
| 178 | 
            +
            	glm::mat3 J = glm::mat3(h_x / t.z, 0.0f, -(h_x * t.x) / (t.z * t.z),
         | 
| 179 | 
            +
            		0.0f, h_y / t.z, -(h_y * t.y) / (t.z * t.z),
         | 
| 180 | 
            +
            		0, 0, 0);
         | 
| 181 | 
            +
             | 
| 182 | 
            +
            	glm::mat3 W = glm::mat3(
         | 
| 183 | 
            +
            		view_matrix[0], view_matrix[4], view_matrix[8],
         | 
| 184 | 
            +
            		view_matrix[1], view_matrix[5], view_matrix[9],
         | 
| 185 | 
            +
            		view_matrix[2], view_matrix[6], view_matrix[10]);
         | 
| 186 | 
            +
             | 
| 187 | 
            +
            	glm::mat3 Vrk = glm::mat3(
         | 
| 188 | 
            +
            		cov3D[0], cov3D[1], cov3D[2],
         | 
| 189 | 
            +
            		cov3D[1], cov3D[3], cov3D[4],
         | 
| 190 | 
            +
            		cov3D[2], cov3D[4], cov3D[5]);
         | 
| 191 | 
            +
             | 
| 192 | 
            +
            	glm::mat3 T = W * J;
         | 
| 193 | 
            +
             | 
| 194 | 
            +
            	glm::mat3 cov2D = glm::transpose(T) * glm::transpose(Vrk) * T;
         | 
| 195 | 
            +
             | 
| 196 | 
            +
            	// Use helper variables for 2D covariance entries. More compact.
         | 
| 197 | 
            +
            	float a = cov2D[0][0] += 0.3f;
         | 
| 198 | 
            +
            	float b = cov2D[0][1];
         | 
| 199 | 
            +
            	float c = cov2D[1][1] += 0.3f;
         | 
| 200 | 
            +
             | 
| 201 | 
            +
            	float denom = a * c - b * b;
         | 
| 202 | 
            +
            	float dL_da = 0, dL_db = 0, dL_dc = 0;
         | 
| 203 | 
            +
            	float denom2inv = 1.0f / ((denom * denom) + 0.0000001f);
         | 
| 204 | 
            +
             | 
| 205 | 
            +
            	if (denom2inv != 0)
         | 
| 206 | 
            +
            	{
         | 
| 207 | 
            +
            		// Gradients of loss w.r.t. entries of 2D covariance matrix,
         | 
| 208 | 
            +
            		// given gradients of loss w.r.t. conic matrix (inverse covariance matrix).
         | 
| 209 | 
            +
            		// e.g., dL / da = dL / d_conic_a * d_conic_a / d_a
         | 
| 210 | 
            +
            		dL_da = denom2inv * (-c * c * dL_dconic.x + 2 * b * c * dL_dconic.y + (denom - a * c) * dL_dconic.z);
         | 
| 211 | 
            +
            		dL_dc = denom2inv * (-a * a * dL_dconic.z + 2 * a * b * dL_dconic.y + (denom - a * c) * dL_dconic.x);
         | 
| 212 | 
            +
            		dL_db = denom2inv * 2 * (b * c * dL_dconic.x - (denom + 2 * b * b) * dL_dconic.y + a * b * dL_dconic.z);
         | 
| 213 | 
            +
             | 
| 214 | 
            +
            		// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry, 
         | 
| 215 | 
            +
            		// given gradients w.r.t. 2D covariance matrix (diagonal).
         | 
| 216 | 
            +
            		// cov2D = transpose(T) * transpose(Vrk) * T;
         | 
| 217 | 
            +
            		dL_dcov[6 * idx + 0] = (T[0][0] * T[0][0] * dL_da + T[0][0] * T[1][0] * dL_db + T[1][0] * T[1][0] * dL_dc);
         | 
| 218 | 
            +
            		dL_dcov[6 * idx + 3] = (T[0][1] * T[0][1] * dL_da + T[0][1] * T[1][1] * dL_db + T[1][1] * T[1][1] * dL_dc);
         | 
| 219 | 
            +
            		dL_dcov[6 * idx + 5] = (T[0][2] * T[0][2] * dL_da + T[0][2] * T[1][2] * dL_db + T[1][2] * T[1][2] * dL_dc);
         | 
| 220 | 
            +
             | 
| 221 | 
            +
            		// Gradients of loss L w.r.t. each 3D covariance matrix (Vrk) entry, 
         | 
| 222 | 
            +
            		// given gradients w.r.t. 2D covariance matrix (off-diagonal).
         | 
| 223 | 
            +
            		// Off-diagonal elements appear twice --> double the gradient.
         | 
| 224 | 
            +
            		// cov2D = transpose(T) * transpose(Vrk) * T;
         | 
| 225 | 
            +
            		dL_dcov[6 * idx + 1] = 2 * T[0][0] * T[0][1] * dL_da + (T[0][0] * T[1][1] + T[0][1] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][1] * dL_dc;
         | 
| 226 | 
            +
            		dL_dcov[6 * idx + 2] = 2 * T[0][0] * T[0][2] * dL_da + (T[0][0] * T[1][2] + T[0][2] * T[1][0]) * dL_db + 2 * T[1][0] * T[1][2] * dL_dc;
         | 
| 227 | 
            +
            		dL_dcov[6 * idx + 4] = 2 * T[0][2] * T[0][1] * dL_da + (T[0][1] * T[1][2] + T[0][2] * T[1][1]) * dL_db + 2 * T[1][1] * T[1][2] * dL_dc;
         | 
| 228 | 
            +
            	}
         | 
| 229 | 
            +
            	else
         | 
| 230 | 
            +
            	{
         | 
| 231 | 
            +
            		for (int i = 0; i < 6; i++)
         | 
| 232 | 
            +
            			dL_dcov[6 * idx + i] = 0;
         | 
| 233 | 
            +
            	}
         | 
| 234 | 
            +
             | 
| 235 | 
            +
            	// Gradients of loss w.r.t. upper 2x3 portion of intermediate matrix T
         | 
| 236 | 
            +
            	// cov2D = transpose(T) * transpose(Vrk) * T;
         | 
| 237 | 
            +
            	float dL_dT00 = 2 * (T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_da +
         | 
| 238 | 
            +
            		(T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_db;
         | 
| 239 | 
            +
            	float dL_dT01 = 2 * (T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_da +
         | 
| 240 | 
            +
            		(T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_db;
         | 
| 241 | 
            +
            	float dL_dT02 = 2 * (T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_da +
         | 
| 242 | 
            +
            		(T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_db;
         | 
| 243 | 
            +
            	float dL_dT10 = 2 * (T[1][0] * Vrk[0][0] + T[1][1] * Vrk[0][1] + T[1][2] * Vrk[0][2]) * dL_dc +
         | 
| 244 | 
            +
            		(T[0][0] * Vrk[0][0] + T[0][1] * Vrk[0][1] + T[0][2] * Vrk[0][2]) * dL_db;
         | 
| 245 | 
            +
            	float dL_dT11 = 2 * (T[1][0] * Vrk[1][0] + T[1][1] * Vrk[1][1] + T[1][2] * Vrk[1][2]) * dL_dc +
         | 
| 246 | 
            +
            		(T[0][0] * Vrk[1][0] + T[0][1] * Vrk[1][1] + T[0][2] * Vrk[1][2]) * dL_db;
         | 
| 247 | 
            +
            	float dL_dT12 = 2 * (T[1][0] * Vrk[2][0] + T[1][1] * Vrk[2][1] + T[1][2] * Vrk[2][2]) * dL_dc +
         | 
| 248 | 
            +
            		(T[0][0] * Vrk[2][0] + T[0][1] * Vrk[2][1] + T[0][2] * Vrk[2][2]) * dL_db;
         | 
| 249 | 
            +
             | 
| 250 | 
            +
            	// Gradients of loss w.r.t. upper 3x2 non-zero entries of Jacobian matrix
         | 
| 251 | 
            +
            	// T = W * J
         | 
| 252 | 
            +
            	float dL_dJ00 = W[0][0] * dL_dT00 + W[0][1] * dL_dT01 + W[0][2] * dL_dT02;
         | 
| 253 | 
            +
            	float dL_dJ02 = W[2][0] * dL_dT00 + W[2][1] * dL_dT01 + W[2][2] * dL_dT02;
         | 
| 254 | 
            +
            	float dL_dJ11 = W[1][0] * dL_dT10 + W[1][1] * dL_dT11 + W[1][2] * dL_dT12;
         | 
| 255 | 
            +
            	float dL_dJ12 = W[2][0] * dL_dT10 + W[2][1] * dL_dT11 + W[2][2] * dL_dT12;
         | 
| 256 | 
            +
             | 
| 257 | 
            +
            	float tz = 1.f / t.z;
         | 
| 258 | 
            +
            	float tz2 = tz * tz;
         | 
| 259 | 
            +
            	float tz3 = tz2 * tz;
         | 
| 260 | 
            +
             | 
| 261 | 
            +
            	// Gradients of loss w.r.t. transformed Gaussian mean t
         | 
| 262 | 
            +
            	float dL_dtx = x_grad_mul * -h_x * tz2 * dL_dJ02;
         | 
| 263 | 
            +
            	float dL_dty = y_grad_mul * -h_y * tz2 * dL_dJ12;
         | 
| 264 | 
            +
            	float dL_dtz = -h_x * tz2 * dL_dJ00 - h_y * tz2 * dL_dJ11 + (2 * h_x * t.x) * tz3 * dL_dJ02 + (2 * h_y * t.y) * tz3 * dL_dJ12;
         | 
| 265 | 
            +
             | 
| 266 | 
            +
            	// Account for transformation of mean to t
         | 
| 267 | 
            +
            	// t = transformPoint4x3(mean, view_matrix);
         | 
| 268 | 
            +
            	float3 dL_dmean = transformVec4x3Transpose({ dL_dtx, dL_dty, dL_dtz }, view_matrix);
         | 
| 269 | 
            +
             | 
| 270 | 
            +
            	// Gradients of loss w.r.t. Gaussian means, but only the portion 
         | 
| 271 | 
            +
            	// that is caused because the mean affects the covariance matrix.
         | 
| 272 | 
            +
            	// Additional mean gradient is accumulated in BACKWARD::preprocess.
         | 
| 273 | 
            +
            	dL_dmeans[idx] = dL_dmean;
         | 
| 274 | 
            +
            }
         | 
| 275 | 
            +
             | 
| 276 | 
            +
            // Backward pass for the conversion of scale and rotation to a 
         | 
| 277 | 
            +
            // 3D covariance matrix for each Gaussian. 
         | 
| 278 | 
            +
            __device__ void computeCov3D(int idx, const glm::vec3 scale, float mod, const glm::vec4 rot, const float* dL_dcov3Ds, glm::vec3* dL_dscales, glm::vec4* dL_drots)
         | 
| 279 | 
            +
            {
         | 
| 280 | 
            +
            	// Recompute (intermediate) results for the 3D covariance computation.
         | 
| 281 | 
            +
            	glm::vec4 q = rot;// / glm::length(rot);
         | 
| 282 | 
            +
            	float r = q.x;
         | 
| 283 | 
            +
            	float x = q.y;
         | 
| 284 | 
            +
            	float y = q.z;
         | 
| 285 | 
            +
            	float z = q.w;
         | 
| 286 | 
            +
             | 
| 287 | 
            +
            	glm::mat3 R = glm::mat3(
         | 
| 288 | 
            +
            		1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y),
         | 
| 289 | 
            +
            		2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
         | 
| 290 | 
            +
            		2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y)
         | 
| 291 | 
            +
            	);
         | 
| 292 | 
            +
             | 
| 293 | 
            +
            	glm::mat3 S = glm::mat3(1.0f);
         | 
| 294 | 
            +
             | 
| 295 | 
            +
            	glm::vec3 s = mod * scale;
         | 
| 296 | 
            +
            	S[0][0] = s.x;
         | 
| 297 | 
            +
            	S[1][1] = s.y;
         | 
| 298 | 
            +
            	S[2][2] = s.z;
         | 
| 299 | 
            +
             | 
| 300 | 
            +
            	glm::mat3 M = S * R;
         | 
| 301 | 
            +
             | 
| 302 | 
            +
            	const float* dL_dcov3D = dL_dcov3Ds + 6 * idx;
         | 
| 303 | 
            +
             | 
| 304 | 
            +
            	glm::vec3 dunc(dL_dcov3D[0], dL_dcov3D[3], dL_dcov3D[5]);
         | 
| 305 | 
            +
            	glm::vec3 ounc = 0.5f * glm::vec3(dL_dcov3D[1], dL_dcov3D[2], dL_dcov3D[4]);
         | 
| 306 | 
            +
             | 
| 307 | 
            +
            	// Convert per-element covariance loss gradients to matrix form
         | 
| 308 | 
            +
            	glm::mat3 dL_dSigma = glm::mat3(
         | 
| 309 | 
            +
            		dL_dcov3D[0], 0.5f * dL_dcov3D[1], 0.5f * dL_dcov3D[2],
         | 
| 310 | 
            +
            		0.5f * dL_dcov3D[1], dL_dcov3D[3], 0.5f * dL_dcov3D[4],
         | 
| 311 | 
            +
            		0.5f * dL_dcov3D[2], 0.5f * dL_dcov3D[4], dL_dcov3D[5]
         | 
| 312 | 
            +
            	);
         | 
| 313 | 
            +
             | 
| 314 | 
            +
            	// Compute loss gradient w.r.t. matrix M
         | 
| 315 | 
            +
            	// dSigma_dM = 2 * M
         | 
| 316 | 
            +
            	glm::mat3 dL_dM = 2.0f * M * dL_dSigma;
         | 
| 317 | 
            +
             | 
| 318 | 
            +
            	glm::mat3 Rt = glm::transpose(R);
         | 
| 319 | 
            +
            	glm::mat3 dL_dMt = glm::transpose(dL_dM);
         | 
| 320 | 
            +
             | 
| 321 | 
            +
            	// Gradients of loss w.r.t. scale
         | 
| 322 | 
            +
            	glm::vec3* dL_dscale = dL_dscales + idx;
         | 
| 323 | 
            +
            	dL_dscale->x = glm::dot(Rt[0], dL_dMt[0]);
         | 
| 324 | 
            +
            	dL_dscale->y = glm::dot(Rt[1], dL_dMt[1]);
         | 
| 325 | 
            +
            	dL_dscale->z = glm::dot(Rt[2], dL_dMt[2]);
         | 
| 326 | 
            +
             | 
| 327 | 
            +
            	dL_dMt[0] *= s.x;
         | 
| 328 | 
            +
            	dL_dMt[1] *= s.y;
         | 
| 329 | 
            +
            	dL_dMt[2] *= s.z;
         | 
| 330 | 
            +
             | 
| 331 | 
            +
            	// Gradients of loss w.r.t. normalized quaternion
         | 
| 332 | 
            +
            	glm::vec4 dL_dq;
         | 
| 333 | 
            +
            	dL_dq.x = 2 * z * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * y * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * x * (dL_dMt[1][2] - dL_dMt[2][1]);
         | 
| 334 | 
            +
            	dL_dq.y = 2 * y * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * z * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * r * (dL_dMt[1][2] - dL_dMt[2][1]) - 4 * x * (dL_dMt[2][2] + dL_dMt[1][1]);
         | 
| 335 | 
            +
            	dL_dq.z = 2 * x * (dL_dMt[1][0] + dL_dMt[0][1]) + 2 * r * (dL_dMt[2][0] - dL_dMt[0][2]) + 2 * z * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * y * (dL_dMt[2][2] + dL_dMt[0][0]);
         | 
| 336 | 
            +
            	dL_dq.w = 2 * r * (dL_dMt[0][1] - dL_dMt[1][0]) + 2 * x * (dL_dMt[2][0] + dL_dMt[0][2]) + 2 * y * (dL_dMt[1][2] + dL_dMt[2][1]) - 4 * z * (dL_dMt[1][1] + dL_dMt[0][0]);
         | 
| 337 | 
            +
             | 
| 338 | 
            +
            	// Gradients of loss w.r.t. unnormalized quaternion
         | 
| 339 | 
            +
            	float4* dL_drot = (float4*)(dL_drots + idx);
         | 
| 340 | 
            +
            	*dL_drot = float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w };//dnormvdv(float4{ rot.x, rot.y, rot.z, rot.w }, float4{ dL_dq.x, dL_dq.y, dL_dq.z, dL_dq.w });
         | 
| 341 | 
            +
            }
         | 
| 342 | 
            +
             | 
| 343 | 
            +
            // Backward pass of the preprocessing steps, except
         | 
| 344 | 
            +
            // for the covariance computation and inversion
         | 
| 345 | 
            +
            // (those are handled by a previous kernel call)
         | 
| 346 | 
            +
            template<int C>
         | 
| 347 | 
            +
            __global__ void preprocessCUDA(
         | 
| 348 | 
            +
            	int P, int D, int M,
         | 
| 349 | 
            +
            	const float3* means,
         | 
| 350 | 
            +
            	const int* radii,
         | 
| 351 | 
            +
            	const float* shs,
         | 
| 352 | 
            +
            	const bool* clamped,
         | 
| 353 | 
            +
            	const glm::vec3* scales,
         | 
| 354 | 
            +
            	const glm::vec4* rotations,
         | 
| 355 | 
            +
            	const float scale_modifier,
         | 
| 356 | 
            +
            	const float* proj,
         | 
| 357 | 
            +
            	const glm::vec3* campos,
         | 
| 358 | 
            +
            	const float3* dL_dmean2D,
         | 
| 359 | 
            +
            	glm::vec3* dL_dmeans,
         | 
| 360 | 
            +
            	float* dL_dcolor,
         | 
| 361 | 
            +
            	float* dL_dcov3D,
         | 
| 362 | 
            +
            	float* dL_dsh,
         | 
| 363 | 
            +
            	glm::vec3* dL_dscale,
         | 
| 364 | 
            +
            	glm::vec4* dL_drot)
         | 
| 365 | 
            +
            {
         | 
| 366 | 
            +
            	auto idx = cg::this_grid().thread_rank();
         | 
| 367 | 
            +
            	if (idx >= P || !(radii[idx] > 0))
         | 
| 368 | 
            +
            		return;
         | 
| 369 | 
            +
             | 
| 370 | 
            +
            	float3 m = means[idx];
         | 
| 371 | 
            +
             | 
| 372 | 
            +
            	// Taking care of gradients from the screenspace points
         | 
| 373 | 
            +
            	float4 m_hom = transformPoint4x4(m, proj);
         | 
| 374 | 
            +
            	float m_w = 1.0f / (m_hom.w + 0.0000001f);
         | 
| 375 | 
            +
             | 
| 376 | 
            +
            	// Compute loss gradient w.r.t. 3D means due to gradients of 2D means
         | 
| 377 | 
            +
            	// from rendering procedure
         | 
| 378 | 
            +
            	glm::vec3 dL_dmean;
         | 
| 379 | 
            +
            	float mul1 = (proj[0] * m.x + proj[4] * m.y + proj[8] * m.z + proj[12]) * m_w * m_w;
         | 
| 380 | 
            +
            	float mul2 = (proj[1] * m.x + proj[5] * m.y + proj[9] * m.z + proj[13]) * m_w * m_w;
         | 
| 381 | 
            +
            	dL_dmean.x = (proj[0] * m_w - proj[3] * mul1) * dL_dmean2D[idx].x + (proj[1] * m_w - proj[3] * mul2) * dL_dmean2D[idx].y;
         | 
| 382 | 
            +
            	dL_dmean.y = (proj[4] * m_w - proj[7] * mul1) * dL_dmean2D[idx].x + (proj[5] * m_w - proj[7] * mul2) * dL_dmean2D[idx].y;
         | 
| 383 | 
            +
            	dL_dmean.z = (proj[8] * m_w - proj[11] * mul1) * dL_dmean2D[idx].x + (proj[9] * m_w - proj[11] * mul2) * dL_dmean2D[idx].y;
         | 
| 384 | 
            +
             | 
| 385 | 
            +
            	// That's the second part of the mean gradient. Previous computation
         | 
| 386 | 
            +
            	// of cov2D and following SH conversion also affects it.
         | 
| 387 | 
            +
            	dL_dmeans[idx] += dL_dmean;
         | 
| 388 | 
            +
             | 
| 389 | 
            +
            	// Compute gradient updates due to computing colors from SHs
         | 
| 390 | 
            +
            	if (shs)
         | 
| 391 | 
            +
            		computeColorFromSH(idx, D, M, (glm::vec3*)means, *campos, shs, clamped, (glm::vec3*)dL_dcolor, (glm::vec3*)dL_dmeans, (glm::vec3*)dL_dsh);
         | 
| 392 | 
            +
             | 
| 393 | 
            +
            	// Compute gradient updates due to computing covariance from scale/rotation
         | 
| 394 | 
            +
            	if (scales)
         | 
| 395 | 
            +
            		computeCov3D(idx, scales[idx], scale_modifier, rotations[idx], dL_dcov3D, dL_dscale, dL_drot);
         | 
| 396 | 
            +
            }
         | 
| 397 | 
            +
             | 
| 398 | 
            +
            // Backward version of the rendering procedure.
         | 
| 399 | 
            +
            template <uint32_t C>
         | 
| 400 | 
            +
            __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
         | 
| 401 | 
            +
            renderCUDA(
         | 
| 402 | 
            +
            	const uint2* __restrict__ ranges,
         | 
| 403 | 
            +
            	const uint32_t* __restrict__ point_list,
         | 
| 404 | 
            +
            	int W, int H,
         | 
| 405 | 
            +
            	const float* __restrict__ bg_color,
         | 
| 406 | 
            +
            	const float2* __restrict__ points_xy_image,
         | 
| 407 | 
            +
            	const float4* __restrict__ conic_opacity,
         | 
| 408 | 
            +
            	const float* __restrict__ colors,
         | 
| 409 | 
            +
            	const float* __restrict__ final_Ts,
         | 
| 410 | 
            +
            	const uint32_t* __restrict__ n_contrib,
         | 
| 411 | 
            +
            	const float* __restrict__ dL_dpixels,
         | 
| 412 | 
            +
            	float3* __restrict__ dL_dmean2D,
         | 
| 413 | 
            +
            	float4* __restrict__ dL_dconic2D,
         | 
| 414 | 
            +
            	float* __restrict__ dL_dopacity,
         | 
| 415 | 
            +
            	float* __restrict__ dL_dcolors)
         | 
| 416 | 
            +
            {
         | 
| 417 | 
            +
            	// We rasterize again. Compute necessary block info.
         | 
| 418 | 
            +
            	auto block = cg::this_thread_block();
         | 
| 419 | 
            +
            	const uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
         | 
| 420 | 
            +
            	const uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y };
         | 
| 421 | 
            +
            	const uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) };
         | 
| 422 | 
            +
            	const uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y };
         | 
| 423 | 
            +
            	const uint32_t pix_id = W * pix.y + pix.x;
         | 
| 424 | 
            +
            	const float2 pixf = { (float)pix.x, (float)pix.y };
         | 
| 425 | 
            +
             | 
| 426 | 
            +
            	const bool inside = pix.x < W&& pix.y < H;
         | 
| 427 | 
            +
            	const uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
         | 
| 428 | 
            +
             | 
| 429 | 
            +
            	const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
         | 
| 430 | 
            +
             | 
| 431 | 
            +
            	bool done = !inside;
         | 
| 432 | 
            +
            	int toDo = range.y - range.x;
         | 
| 433 | 
            +
             | 
| 434 | 
            +
            	__shared__ int collected_id[BLOCK_SIZE];
         | 
| 435 | 
            +
            	__shared__ float2 collected_xy[BLOCK_SIZE];
         | 
| 436 | 
            +
            	__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
         | 
| 437 | 
            +
            	__shared__ float collected_colors[C * BLOCK_SIZE];
         | 
| 438 | 
            +
             | 
| 439 | 
            +
            	// In the forward, we stored the final value for T, the
         | 
| 440 | 
            +
            	// product of all (1 - alpha) factors. 
         | 
| 441 | 
            +
            	const float T_final = inside ? final_Ts[pix_id] : 0;
         | 
| 442 | 
            +
            	float T = T_final;
         | 
| 443 | 
            +
             | 
| 444 | 
            +
            	// We start from the back. The ID of the last contributing
         | 
| 445 | 
            +
            	// Gaussian is known from each pixel from the forward.
         | 
| 446 | 
            +
            	uint32_t contributor = toDo;
         | 
| 447 | 
            +
            	const int last_contributor = inside ? n_contrib[pix_id] : 0;
         | 
| 448 | 
            +
             | 
| 449 | 
            +
            	float accum_rec[C] = { 0 };
         | 
| 450 | 
            +
            	float dL_dpixel[C];
         | 
| 451 | 
            +
            	if (inside)
         | 
| 452 | 
            +
            		for (int i = 0; i < C; i++)
         | 
| 453 | 
            +
            			dL_dpixel[i] = dL_dpixels[i * H * W + pix_id];
         | 
| 454 | 
            +
             | 
| 455 | 
            +
            	float last_alpha = 0;
         | 
| 456 | 
            +
            	float last_color[C] = { 0 };
         | 
| 457 | 
            +
             | 
| 458 | 
            +
            	// Gradient of pixel coordinate w.r.t. normalized 
         | 
| 459 | 
            +
            	// screen-space viewport corrdinates (-1 to 1)
         | 
| 460 | 
            +
            	const float ddelx_dx = 0.5 * W;
         | 
| 461 | 
            +
            	const float ddely_dy = 0.5 * H;
         | 
| 462 | 
            +
             | 
| 463 | 
            +
            	// Traverse all Gaussians
         | 
| 464 | 
            +
            	for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
         | 
| 465 | 
            +
            	{
         | 
| 466 | 
            +
            		// Load auxiliary data into shared memory, start in the BACK
         | 
| 467 | 
            +
            		// and load them in revers order.
         | 
| 468 | 
            +
            		block.sync();
         | 
| 469 | 
            +
            		const int progress = i * BLOCK_SIZE + block.thread_rank();
         | 
| 470 | 
            +
            		if (range.x + progress < range.y)
         | 
| 471 | 
            +
            		{
         | 
| 472 | 
            +
            			const int coll_id = point_list[range.y - progress - 1];
         | 
| 473 | 
            +
            			collected_id[block.thread_rank()] = coll_id;
         | 
| 474 | 
            +
            			collected_xy[block.thread_rank()] = points_xy_image[coll_id];
         | 
| 475 | 
            +
            			collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
         | 
| 476 | 
            +
            			for (int i = 0; i < C; i++)
         | 
| 477 | 
            +
            				collected_colors[i * BLOCK_SIZE + block.thread_rank()] = colors[coll_id * C + i];
         | 
| 478 | 
            +
            		}
         | 
| 479 | 
            +
            		block.sync();
         | 
| 480 | 
            +
             | 
| 481 | 
            +
            		// Iterate over Gaussians
         | 
| 482 | 
            +
            		for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
         | 
| 483 | 
            +
            		{
         | 
| 484 | 
            +
            			// Keep track of current Gaussian ID. Skip, if this one
         | 
| 485 | 
            +
            			// is behind the last contributor for this pixel.
         | 
| 486 | 
            +
            			contributor--;
         | 
| 487 | 
            +
            			if (contributor >= last_contributor)
         | 
| 488 | 
            +
            				continue;
         | 
| 489 | 
            +
             | 
| 490 | 
            +
            			// Compute blending values, as before.
         | 
| 491 | 
            +
            			const float2 xy = collected_xy[j];
         | 
| 492 | 
            +
            			const float2 d = { xy.x - pixf.x, xy.y - pixf.y };
         | 
| 493 | 
            +
            			const float4 con_o = collected_conic_opacity[j];
         | 
| 494 | 
            +
            			const float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
         | 
| 495 | 
            +
            			if (power > 0.0f)
         | 
| 496 | 
            +
            				continue;
         | 
| 497 | 
            +
             | 
| 498 | 
            +
            			const float G = exp(power);
         | 
| 499 | 
            +
            			const float alpha = min(0.99f, con_o.w * G);
         | 
| 500 | 
            +
            			if (alpha < 1.0f / 255.0f)
         | 
| 501 | 
            +
            				continue;
         | 
| 502 | 
            +
             | 
| 503 | 
            +
            			T = T / (1.f - alpha);
         | 
| 504 | 
            +
            			const float dchannel_dcolor = alpha * T;
         | 
| 505 | 
            +
             | 
| 506 | 
            +
            			// Propagate gradients to per-Gaussian colors and keep
         | 
| 507 | 
            +
            			// gradients w.r.t. alpha (blending factor for a Gaussian/pixel
         | 
| 508 | 
            +
            			// pair).
         | 
| 509 | 
            +
            			float dL_dalpha = 0.0f;
         | 
| 510 | 
            +
            			const int global_id = collected_id[j];
         | 
| 511 | 
            +
            			for (int ch = 0; ch < C; ch++)
         | 
| 512 | 
            +
            			{
         | 
| 513 | 
            +
            				const float c = collected_colors[ch * BLOCK_SIZE + j];
         | 
| 514 | 
            +
            				// Update last color (to be used in the next iteration)
         | 
| 515 | 
            +
            				accum_rec[ch] = last_alpha * last_color[ch] + (1.f - last_alpha) * accum_rec[ch];
         | 
| 516 | 
            +
            				last_color[ch] = c;
         | 
| 517 | 
            +
             | 
| 518 | 
            +
            				const float dL_dchannel = dL_dpixel[ch];
         | 
| 519 | 
            +
            				dL_dalpha += (c - accum_rec[ch]) * dL_dchannel;
         | 
| 520 | 
            +
            				// Update the gradients w.r.t. color of the Gaussian. 
         | 
| 521 | 
            +
            				// Atomic, since this pixel is just one of potentially
         | 
| 522 | 
            +
            				// many that were affected by this Gaussian.
         | 
| 523 | 
            +
            				atomicAdd(&(dL_dcolors[global_id * C + ch]), dchannel_dcolor * dL_dchannel);
         | 
| 524 | 
            +
            			}
         | 
| 525 | 
            +
            			dL_dalpha *= T;
         | 
| 526 | 
            +
            			// Update last alpha (to be used in the next iteration)
         | 
| 527 | 
            +
            			last_alpha = alpha;
         | 
| 528 | 
            +
             | 
| 529 | 
            +
            			// Account for fact that alpha also influences how much of
         | 
| 530 | 
            +
            			// the background color is added if nothing left to blend
         | 
| 531 | 
            +
            			float bg_dot_dpixel = 0;
         | 
| 532 | 
            +
            			for (int i = 0; i < C; i++)
         | 
| 533 | 
            +
            				bg_dot_dpixel += bg_color[i] * dL_dpixel[i];
         | 
| 534 | 
            +
            			dL_dalpha += (-T_final / (1.f - alpha)) * bg_dot_dpixel;
         | 
| 535 | 
            +
             | 
| 536 | 
            +
             | 
| 537 | 
            +
            			// Helpful reusable temporary variables
         | 
| 538 | 
            +
            			const float dL_dG = con_o.w * dL_dalpha;
         | 
| 539 | 
            +
            			const float gdx = G * d.x;
         | 
| 540 | 
            +
            			const float gdy = G * d.y;
         | 
| 541 | 
            +
            			const float dG_ddelx = -gdx * con_o.x - gdy * con_o.y;
         | 
| 542 | 
            +
            			const float dG_ddely = -gdy * con_o.z - gdx * con_o.y;
         | 
| 543 | 
            +
             | 
| 544 | 
            +
            			// Update gradients w.r.t. 2D mean position of the Gaussian
         | 
| 545 | 
            +
            			atomicAdd(&dL_dmean2D[global_id].x, dL_dG * dG_ddelx * ddelx_dx);
         | 
| 546 | 
            +
            			atomicAdd(&dL_dmean2D[global_id].y, dL_dG * dG_ddely * ddely_dy);
         | 
| 547 | 
            +
             | 
| 548 | 
            +
            			// Update gradients w.r.t. 2D covariance (2x2 matrix, symmetric)
         | 
| 549 | 
            +
            			atomicAdd(&dL_dconic2D[global_id].x, -0.5f * gdx * d.x * dL_dG);
         | 
| 550 | 
            +
            			atomicAdd(&dL_dconic2D[global_id].y, -0.5f * gdx * d.y * dL_dG);
         | 
| 551 | 
            +
            			atomicAdd(&dL_dconic2D[global_id].w, -0.5f * gdy * d.y * dL_dG);
         | 
| 552 | 
            +
             | 
| 553 | 
            +
            			// Update gradients w.r.t. opacity of the Gaussian
         | 
| 554 | 
            +
            			atomicAdd(&(dL_dopacity[global_id]), G * dL_dalpha);
         | 
| 555 | 
            +
            		}
         | 
| 556 | 
            +
            	}
         | 
| 557 | 
            +
            }
         | 
| 558 | 
            +
             | 
| 559 | 
            +
            void BACKWARD::preprocess(
         | 
| 560 | 
            +
            	int P, int D, int M,
         | 
| 561 | 
            +
            	const float3* means3D,
         | 
| 562 | 
            +
            	const int* radii,
         | 
| 563 | 
            +
            	const float* shs,
         | 
| 564 | 
            +
            	const bool* clamped,
         | 
| 565 | 
            +
            	const glm::vec3* scales,
         | 
| 566 | 
            +
            	const glm::vec4* rotations,
         | 
| 567 | 
            +
            	const float scale_modifier,
         | 
| 568 | 
            +
            	const float* cov3Ds,
         | 
| 569 | 
            +
            	const float* viewmatrix,
         | 
| 570 | 
            +
            	const float* projmatrix,
         | 
| 571 | 
            +
            	const float focal_x, float focal_y,
         | 
| 572 | 
            +
            	const float tan_fovx, float tan_fovy,
         | 
| 573 | 
            +
            	const glm::vec3* campos,
         | 
| 574 | 
            +
            	const float3* dL_dmean2D,
         | 
| 575 | 
            +
            	const float* dL_dconic,
         | 
| 576 | 
            +
            	glm::vec3* dL_dmean3D,
         | 
| 577 | 
            +
            	float* dL_dcolor,
         | 
| 578 | 
            +
            	float* dL_dcov3D,
         | 
| 579 | 
            +
            	float* dL_dsh,
         | 
| 580 | 
            +
            	glm::vec3* dL_dscale,
         | 
| 581 | 
            +
            	glm::vec4* dL_drot)
         | 
| 582 | 
            +
            {
         | 
| 583 | 
            +
            	// Propagate gradients for the path of 2D conic matrix computation. 
         | 
| 584 | 
            +
            	// Somewhat long, thus it is its own kernel rather than being part of 
         | 
| 585 | 
            +
            	// "preprocess". When done, loss gradient w.r.t. 3D means has been
         | 
| 586 | 
            +
            	// modified and gradient w.r.t. 3D covariance matrix has been computed.	
         | 
| 587 | 
            +
            	computeCov2DCUDA << <(P + 255) / 256, 256 >> > (
         | 
| 588 | 
            +
            		P,
         | 
| 589 | 
            +
            		means3D,
         | 
| 590 | 
            +
            		radii,
         | 
| 591 | 
            +
            		cov3Ds,
         | 
| 592 | 
            +
            		focal_x,
         | 
| 593 | 
            +
            		focal_y,
         | 
| 594 | 
            +
            		tan_fovx,
         | 
| 595 | 
            +
            		tan_fovy,
         | 
| 596 | 
            +
            		viewmatrix,
         | 
| 597 | 
            +
            		dL_dconic,
         | 
| 598 | 
            +
            		(float3*)dL_dmean3D,
         | 
| 599 | 
            +
            		dL_dcov3D);
         | 
| 600 | 
            +
             | 
| 601 | 
            +
            	// Propagate gradients for remaining steps: finish 3D mean gradients,
         | 
| 602 | 
            +
            	// propagate color gradients to SH (if desireD), propagate 3D covariance
         | 
| 603 | 
            +
            	// matrix gradients to scale and rotation.
         | 
| 604 | 
            +
            	preprocessCUDA<NUM_CHANNELS> << < (P + 255) / 256, 256 >> > (
         | 
| 605 | 
            +
            		P, D, M,
         | 
| 606 | 
            +
            		(float3*)means3D,
         | 
| 607 | 
            +
            		radii,
         | 
| 608 | 
            +
            		shs,
         | 
| 609 | 
            +
            		clamped,
         | 
| 610 | 
            +
            		(glm::vec3*)scales,
         | 
| 611 | 
            +
            		(glm::vec4*)rotations,
         | 
| 612 | 
            +
            		scale_modifier,
         | 
| 613 | 
            +
            		projmatrix,
         | 
| 614 | 
            +
            		campos,
         | 
| 615 | 
            +
            		(float3*)dL_dmean2D,
         | 
| 616 | 
            +
            		(glm::vec3*)dL_dmean3D,
         | 
| 617 | 
            +
            		dL_dcolor,
         | 
| 618 | 
            +
            		dL_dcov3D,
         | 
| 619 | 
            +
            		dL_dsh,
         | 
| 620 | 
            +
            		dL_dscale,
         | 
| 621 | 
            +
            		dL_drot);
         | 
| 622 | 
            +
            }
         | 
| 623 | 
            +
             | 
| 624 | 
            +
            void BACKWARD::render(
         | 
| 625 | 
            +
            	const dim3 grid, const dim3 block,
         | 
| 626 | 
            +
            	const uint2* ranges,
         | 
| 627 | 
            +
            	const uint32_t* point_list,
         | 
| 628 | 
            +
            	int W, int H,
         | 
| 629 | 
            +
            	const float* bg_color,
         | 
| 630 | 
            +
            	const float2* means2D,
         | 
| 631 | 
            +
            	const float4* conic_opacity,
         | 
| 632 | 
            +
            	const float* colors,
         | 
| 633 | 
            +
            	const float* final_Ts,
         | 
| 634 | 
            +
            	const uint32_t* n_contrib,
         | 
| 635 | 
            +
            	const float* dL_dpixels,
         | 
| 636 | 
            +
            	float3* dL_dmean2D,
         | 
| 637 | 
            +
            	float4* dL_dconic2D,
         | 
| 638 | 
            +
            	float* dL_dopacity,
         | 
| 639 | 
            +
            	float* dL_dcolors)
         | 
| 640 | 
            +
            {
         | 
| 641 | 
            +
            	renderCUDA<NUM_CHANNELS> << <grid, block >> >(
         | 
| 642 | 
            +
            		ranges,
         | 
| 643 | 
            +
            		point_list,
         | 
| 644 | 
            +
            		W, H,
         | 
| 645 | 
            +
            		bg_color,
         | 
| 646 | 
            +
            		means2D,
         | 
| 647 | 
            +
            		conic_opacity,
         | 
| 648 | 
            +
            		colors,
         | 
| 649 | 
            +
            		final_Ts,
         | 
| 650 | 
            +
            		n_contrib,
         | 
| 651 | 
            +
            		dL_dpixels,
         | 
| 652 | 
            +
            		dL_dmean2D,
         | 
| 653 | 
            +
            		dL_dconic2D,
         | 
| 654 | 
            +
            		dL_dopacity,
         | 
| 655 | 
            +
            		dL_dcolors
         | 
| 656 | 
            +
            		);
         | 
| 657 | 
            +
            }
         | 
    	
        submodules/diff-gaussian-rasterization/cuda_rasterizer/backward.h
    ADDED
    
    | @@ -0,0 +1,65 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #ifndef CUDA_RASTERIZER_BACKWARD_H_INCLUDED
         | 
| 13 | 
            +
            #define CUDA_RASTERIZER_BACKWARD_H_INCLUDED
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            #include <cuda.h>
         | 
| 16 | 
            +
            #include "cuda_runtime.h"
         | 
| 17 | 
            +
            #include "device_launch_parameters.h"
         | 
| 18 | 
            +
            #define GLM_FORCE_CUDA
         | 
| 19 | 
            +
            #include <glm/glm.hpp>
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            namespace BACKWARD
         | 
| 22 | 
            +
            {
         | 
| 23 | 
            +
            	void render(
         | 
| 24 | 
            +
            		const dim3 grid, dim3 block,
         | 
| 25 | 
            +
            		const uint2* ranges,
         | 
| 26 | 
            +
            		const uint32_t* point_list,
         | 
| 27 | 
            +
            		int W, int H,
         | 
| 28 | 
            +
            		const float* bg_color,
         | 
| 29 | 
            +
            		const float2* means2D,
         | 
| 30 | 
            +
            		const float4* conic_opacity,
         | 
| 31 | 
            +
            		const float* colors,
         | 
| 32 | 
            +
            		const float* final_Ts,
         | 
| 33 | 
            +
            		const uint32_t* n_contrib,
         | 
| 34 | 
            +
            		const float* dL_dpixels,
         | 
| 35 | 
            +
            		float3* dL_dmean2D,
         | 
| 36 | 
            +
            		float4* dL_dconic2D,
         | 
| 37 | 
            +
            		float* dL_dopacity,
         | 
| 38 | 
            +
            		float* dL_dcolors);
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            	void preprocess(
         | 
| 41 | 
            +
            		int P, int D, int M,
         | 
| 42 | 
            +
            		const float3* means,
         | 
| 43 | 
            +
            		const int* radii,
         | 
| 44 | 
            +
            		const float* shs,
         | 
| 45 | 
            +
            		const bool* clamped,
         | 
| 46 | 
            +
            		const glm::vec3* scales,
         | 
| 47 | 
            +
            		const glm::vec4* rotations,
         | 
| 48 | 
            +
            		const float scale_modifier,
         | 
| 49 | 
            +
            		const float* cov3Ds,
         | 
| 50 | 
            +
            		const float* view,
         | 
| 51 | 
            +
            		const float* proj,
         | 
| 52 | 
            +
            		const float focal_x, float focal_y,
         | 
| 53 | 
            +
            		const float tan_fovx, float tan_fovy,
         | 
| 54 | 
            +
            		const glm::vec3* campos,
         | 
| 55 | 
            +
            		const float3* dL_dmean2D,
         | 
| 56 | 
            +
            		const float* dL_dconics,
         | 
| 57 | 
            +
            		glm::vec3* dL_dmeans,
         | 
| 58 | 
            +
            		float* dL_dcolor,
         | 
| 59 | 
            +
            		float* dL_dcov3D,
         | 
| 60 | 
            +
            		float* dL_dsh,
         | 
| 61 | 
            +
            		glm::vec3* dL_dscale,
         | 
| 62 | 
            +
            		glm::vec4* dL_drot);
         | 
| 63 | 
            +
            }
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            #endif
         | 
    	
        submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h
    ADDED
    
    | @@ -0,0 +1,19 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #ifndef CUDA_RASTERIZER_CONFIG_H_INCLUDED
         | 
| 13 | 
            +
            #define CUDA_RASTERIZER_CONFIG_H_INCLUDED
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            #define NUM_CHANNELS 3 // Default 3, RGB
         | 
| 16 | 
            +
            #define BLOCK_X 16
         | 
| 17 | 
            +
            #define BLOCK_Y 16
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            #endif
         | 
    	
        submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.cu
    ADDED
    
    | @@ -0,0 +1,455 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #include "forward.h"
         | 
| 13 | 
            +
            #include "auxiliary.h"
         | 
| 14 | 
            +
            #include <cooperative_groups.h>
         | 
| 15 | 
            +
            #include <cooperative_groups/reduce.h>
         | 
| 16 | 
            +
            namespace cg = cooperative_groups;
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            // Forward method for converting the input spherical harmonics
         | 
| 19 | 
            +
            // coefficients of each Gaussian to a simple RGB color.
         | 
| 20 | 
            +
            __device__ glm::vec3 computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, bool* clamped)
         | 
| 21 | 
            +
            {
         | 
| 22 | 
            +
            	// The implementation is loosely based on code for 
         | 
| 23 | 
            +
            	// "Differentiable Point-Based Radiance Fields for 
         | 
| 24 | 
            +
            	// Efficient View Synthesis" by Zhang et al. (2022)
         | 
| 25 | 
            +
            	glm::vec3 pos = means[idx];
         | 
| 26 | 
            +
            	glm::vec3 dir = pos - campos;
         | 
| 27 | 
            +
            	dir = dir / glm::length(dir);
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            	glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;
         | 
| 30 | 
            +
            	glm::vec3 result = SH_C0 * sh[0];
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            	if (deg > 0)
         | 
| 33 | 
            +
            	{
         | 
| 34 | 
            +
            		float x = dir.x;
         | 
| 35 | 
            +
            		float y = dir.y;
         | 
| 36 | 
            +
            		float z = dir.z;
         | 
| 37 | 
            +
            		result = result - SH_C1 * y * sh[1] + SH_C1 * z * sh[2] - SH_C1 * x * sh[3];
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            		if (deg > 1)
         | 
| 40 | 
            +
            		{
         | 
| 41 | 
            +
            			float xx = x * x, yy = y * y, zz = z * z;
         | 
| 42 | 
            +
            			float xy = x * y, yz = y * z, xz = x * z;
         | 
| 43 | 
            +
            			result = result +
         | 
| 44 | 
            +
            				SH_C2[0] * xy * sh[4] +
         | 
| 45 | 
            +
            				SH_C2[1] * yz * sh[5] +
         | 
| 46 | 
            +
            				SH_C2[2] * (2.0f * zz - xx - yy) * sh[6] +
         | 
| 47 | 
            +
            				SH_C2[3] * xz * sh[7] +
         | 
| 48 | 
            +
            				SH_C2[4] * (xx - yy) * sh[8];
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            			if (deg > 2)
         | 
| 51 | 
            +
            			{
         | 
| 52 | 
            +
            				result = result +
         | 
| 53 | 
            +
            					SH_C3[0] * y * (3.0f * xx - yy) * sh[9] +
         | 
| 54 | 
            +
            					SH_C3[1] * xy * z * sh[10] +
         | 
| 55 | 
            +
            					SH_C3[2] * y * (4.0f * zz - xx - yy) * sh[11] +
         | 
| 56 | 
            +
            					SH_C3[3] * z * (2.0f * zz - 3.0f * xx - 3.0f * yy) * sh[12] +
         | 
| 57 | 
            +
            					SH_C3[4] * x * (4.0f * zz - xx - yy) * sh[13] +
         | 
| 58 | 
            +
            					SH_C3[5] * z * (xx - yy) * sh[14] +
         | 
| 59 | 
            +
            					SH_C3[6] * x * (xx - 3.0f * yy) * sh[15];
         | 
| 60 | 
            +
            			}
         | 
| 61 | 
            +
            		}
         | 
| 62 | 
            +
            	}
         | 
| 63 | 
            +
            	result += 0.5f;
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            	// RGB colors are clamped to positive values. If values are
         | 
| 66 | 
            +
            	// clamped, we need to keep track of this for the backward pass.
         | 
| 67 | 
            +
            	clamped[3 * idx + 0] = (result.x < 0);
         | 
| 68 | 
            +
            	clamped[3 * idx + 1] = (result.y < 0);
         | 
| 69 | 
            +
            	clamped[3 * idx + 2] = (result.z < 0);
         | 
| 70 | 
            +
            	return glm::max(result, 0.0f);
         | 
| 71 | 
            +
            }
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            // Forward version of 2D covariance matrix computation
         | 
| 74 | 
            +
            __device__ float3 computeCov2D(const float3& mean, float focal_x, float focal_y, float tan_fovx, float tan_fovy, const float* cov3D, const float* viewmatrix)
         | 
| 75 | 
            +
            {
         | 
| 76 | 
            +
            	// The following models the steps outlined by equations 29
         | 
| 77 | 
            +
            	// and 31 in "EWA Splatting" (Zwicker et al., 2002). 
         | 
| 78 | 
            +
            	// Additionally considers aspect / scaling of viewport.
         | 
| 79 | 
            +
            	// Transposes used to account for row-/column-major conventions.
         | 
| 80 | 
            +
            	float3 t = transformPoint4x3(mean, viewmatrix);
         | 
| 81 | 
            +
             | 
| 82 | 
            +
            	const float limx = 1.3f * tan_fovx;
         | 
| 83 | 
            +
            	const float limy = 1.3f * tan_fovy;
         | 
| 84 | 
            +
            	const float txtz = t.x / t.z;
         | 
| 85 | 
            +
            	const float tytz = t.y / t.z;
         | 
| 86 | 
            +
            	t.x = min(limx, max(-limx, txtz)) * t.z;
         | 
| 87 | 
            +
            	t.y = min(limy, max(-limy, tytz)) * t.z;
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            	glm::mat3 J = glm::mat3(
         | 
| 90 | 
            +
            		focal_x / t.z, 0.0f, -(focal_x * t.x) / (t.z * t.z),
         | 
| 91 | 
            +
            		0.0f, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z),
         | 
| 92 | 
            +
            		0, 0, 0);
         | 
| 93 | 
            +
             | 
| 94 | 
            +
            	glm::mat3 W = glm::mat3(
         | 
| 95 | 
            +
            		viewmatrix[0], viewmatrix[4], viewmatrix[8],
         | 
| 96 | 
            +
            		viewmatrix[1], viewmatrix[5], viewmatrix[9],
         | 
| 97 | 
            +
            		viewmatrix[2], viewmatrix[6], viewmatrix[10]);
         | 
| 98 | 
            +
             | 
| 99 | 
            +
            	glm::mat3 T = W * J;
         | 
| 100 | 
            +
             | 
| 101 | 
            +
            	glm::mat3 Vrk = glm::mat3(
         | 
| 102 | 
            +
            		cov3D[0], cov3D[1], cov3D[2],
         | 
| 103 | 
            +
            		cov3D[1], cov3D[3], cov3D[4],
         | 
| 104 | 
            +
            		cov3D[2], cov3D[4], cov3D[5]);
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            	glm::mat3 cov = glm::transpose(T) * glm::transpose(Vrk) * T;
         | 
| 107 | 
            +
             | 
| 108 | 
            +
            	// Apply low-pass filter: every Gaussian should be at least
         | 
| 109 | 
            +
            	// one pixel wide/high. Discard 3rd row and column.
         | 
| 110 | 
            +
            	cov[0][0] += 0.3f;
         | 
| 111 | 
            +
            	cov[1][1] += 0.3f;
         | 
| 112 | 
            +
            	return { float(cov[0][0]), float(cov[0][1]), float(cov[1][1]) };
         | 
| 113 | 
            +
            }
         | 
| 114 | 
            +
             | 
| 115 | 
            +
            // Forward method for converting scale and rotation properties of each
         | 
| 116 | 
            +
            // Gaussian to a 3D covariance matrix in world space. Also takes care
         | 
| 117 | 
            +
            // of quaternion normalization.
         | 
| 118 | 
            +
            __device__ void computeCov3D(const glm::vec3 scale, float mod, const glm::vec4 rot, float* cov3D)
         | 
| 119 | 
            +
            {
         | 
| 120 | 
            +
            	// Create scaling matrix
         | 
| 121 | 
            +
            	glm::mat3 S = glm::mat3(1.0f);
         | 
| 122 | 
            +
            	S[0][0] = mod * scale.x;
         | 
| 123 | 
            +
            	S[1][1] = mod * scale.y;
         | 
| 124 | 
            +
            	S[2][2] = mod * scale.z;
         | 
| 125 | 
            +
             | 
| 126 | 
            +
            	// Normalize quaternion to get valid rotation
         | 
| 127 | 
            +
            	glm::vec4 q = rot;// / glm::length(rot);
         | 
| 128 | 
            +
            	float r = q.x;
         | 
| 129 | 
            +
            	float x = q.y;
         | 
| 130 | 
            +
            	float y = q.z;
         | 
| 131 | 
            +
            	float z = q.w;
         | 
| 132 | 
            +
             | 
| 133 | 
            +
            	// Compute rotation matrix from quaternion
         | 
| 134 | 
            +
            	glm::mat3 R = glm::mat3(
         | 
| 135 | 
            +
            		1.f - 2.f * (y * y + z * z), 2.f * (x * y - r * z), 2.f * (x * z + r * y),
         | 
| 136 | 
            +
            		2.f * (x * y + r * z), 1.f - 2.f * (x * x + z * z), 2.f * (y * z - r * x),
         | 
| 137 | 
            +
            		2.f * (x * z - r * y), 2.f * (y * z + r * x), 1.f - 2.f * (x * x + y * y)
         | 
| 138 | 
            +
            	);
         | 
| 139 | 
            +
             | 
| 140 | 
            +
            	glm::mat3 M = S * R;
         | 
| 141 | 
            +
             | 
| 142 | 
            +
            	// Compute 3D world covariance matrix Sigma
         | 
| 143 | 
            +
            	glm::mat3 Sigma = glm::transpose(M) * M;
         | 
| 144 | 
            +
             | 
| 145 | 
            +
            	// Covariance is symmetric, only store upper right
         | 
| 146 | 
            +
            	cov3D[0] = Sigma[0][0];
         | 
| 147 | 
            +
            	cov3D[1] = Sigma[0][1];
         | 
| 148 | 
            +
            	cov3D[2] = Sigma[0][2];
         | 
| 149 | 
            +
            	cov3D[3] = Sigma[1][1];
         | 
| 150 | 
            +
            	cov3D[4] = Sigma[1][2];
         | 
| 151 | 
            +
            	cov3D[5] = Sigma[2][2];
         | 
| 152 | 
            +
            }
         | 
| 153 | 
            +
             | 
| 154 | 
            +
            // Perform initial steps for each Gaussian prior to rasterization.
         | 
| 155 | 
            +
            template<int C>
         | 
| 156 | 
            +
            __global__ void preprocessCUDA(int P, int D, int M,
         | 
| 157 | 
            +
            	const float* orig_points,
         | 
| 158 | 
            +
            	const glm::vec3* scales,
         | 
| 159 | 
            +
            	const float scale_modifier,
         | 
| 160 | 
            +
            	const glm::vec4* rotations,
         | 
| 161 | 
            +
            	const float* opacities,
         | 
| 162 | 
            +
            	const float* shs,
         | 
| 163 | 
            +
            	bool* clamped,
         | 
| 164 | 
            +
            	const float* cov3D_precomp,
         | 
| 165 | 
            +
            	const float* colors_precomp,
         | 
| 166 | 
            +
            	const float* viewmatrix,
         | 
| 167 | 
            +
            	const float* projmatrix,
         | 
| 168 | 
            +
            	const glm::vec3* cam_pos,
         | 
| 169 | 
            +
            	const int W, int H,
         | 
| 170 | 
            +
            	const float tan_fovx, float tan_fovy,
         | 
| 171 | 
            +
            	const float focal_x, float focal_y,
         | 
| 172 | 
            +
            	int* radii,
         | 
| 173 | 
            +
            	float2* points_xy_image,
         | 
| 174 | 
            +
            	float* depths,
         | 
| 175 | 
            +
            	float* cov3Ds,
         | 
| 176 | 
            +
            	float* rgb,
         | 
| 177 | 
            +
            	float4* conic_opacity,
         | 
| 178 | 
            +
            	const dim3 grid,
         | 
| 179 | 
            +
            	uint32_t* tiles_touched,
         | 
| 180 | 
            +
            	bool prefiltered)
         | 
| 181 | 
            +
            {
         | 
| 182 | 
            +
            	auto idx = cg::this_grid().thread_rank();
         | 
| 183 | 
            +
            	if (idx >= P)
         | 
| 184 | 
            +
            		return;
         | 
| 185 | 
            +
             | 
| 186 | 
            +
            	// Initialize radius and touched tiles to 0. If this isn't changed,
         | 
| 187 | 
            +
            	// this Gaussian will not be processed further.
         | 
| 188 | 
            +
            	radii[idx] = 0;
         | 
| 189 | 
            +
            	tiles_touched[idx] = 0;
         | 
| 190 | 
            +
             | 
| 191 | 
            +
            	// Perform near culling, quit if outside.
         | 
| 192 | 
            +
            	float3 p_view;
         | 
| 193 | 
            +
            	if (!in_frustum(idx, orig_points, viewmatrix, projmatrix, prefiltered, p_view))
         | 
| 194 | 
            +
            		return;
         | 
| 195 | 
            +
             | 
| 196 | 
            +
            	// Transform point by projecting
         | 
| 197 | 
            +
            	float3 p_orig = { orig_points[3 * idx], orig_points[3 * idx + 1], orig_points[3 * idx + 2] };
         | 
| 198 | 
            +
            	float4 p_hom = transformPoint4x4(p_orig, projmatrix);
         | 
| 199 | 
            +
            	float p_w = 1.0f / (p_hom.w + 0.0000001f);
         | 
| 200 | 
            +
            	float3 p_proj = { p_hom.x * p_w, p_hom.y * p_w, p_hom.z * p_w };
         | 
| 201 | 
            +
             | 
| 202 | 
            +
            	// If 3D covariance matrix is precomputed, use it, otherwise compute
         | 
| 203 | 
            +
            	// from scaling and rotation parameters. 
         | 
| 204 | 
            +
            	const float* cov3D;
         | 
| 205 | 
            +
            	if (cov3D_precomp != nullptr)
         | 
| 206 | 
            +
            	{
         | 
| 207 | 
            +
            		cov3D = cov3D_precomp + idx * 6;
         | 
| 208 | 
            +
            	}
         | 
| 209 | 
            +
            	else
         | 
| 210 | 
            +
            	{
         | 
| 211 | 
            +
            		computeCov3D(scales[idx], scale_modifier, rotations[idx], cov3Ds + idx * 6);
         | 
| 212 | 
            +
            		cov3D = cov3Ds + idx * 6;
         | 
| 213 | 
            +
            	}
         | 
| 214 | 
            +
             | 
| 215 | 
            +
            	// Compute 2D screen-space covariance matrix
         | 
| 216 | 
            +
            	float3 cov = computeCov2D(p_orig, focal_x, focal_y, tan_fovx, tan_fovy, cov3D, viewmatrix);
         | 
| 217 | 
            +
             | 
| 218 | 
            +
            	// Invert covariance (EWA algorithm)
         | 
| 219 | 
            +
            	float det = (cov.x * cov.z - cov.y * cov.y);
         | 
| 220 | 
            +
            	if (det == 0.0f)
         | 
| 221 | 
            +
            		return;
         | 
| 222 | 
            +
            	float det_inv = 1.f / det;
         | 
| 223 | 
            +
            	float3 conic = { cov.z * det_inv, -cov.y * det_inv, cov.x * det_inv };
         | 
| 224 | 
            +
             | 
| 225 | 
            +
            	// Compute extent in screen space (by finding eigenvalues of
         | 
| 226 | 
            +
            	// 2D covariance matrix). Use extent to compute a bounding rectangle
         | 
| 227 | 
            +
            	// of screen-space tiles that this Gaussian overlaps with. Quit if
         | 
| 228 | 
            +
            	// rectangle covers 0 tiles. 
         | 
| 229 | 
            +
            	float mid = 0.5f * (cov.x + cov.z);
         | 
| 230 | 
            +
            	float lambda1 = mid + sqrt(max(0.1f, mid * mid - det));
         | 
| 231 | 
            +
            	float lambda2 = mid - sqrt(max(0.1f, mid * mid - det));
         | 
| 232 | 
            +
            	float my_radius = ceil(3.f * sqrt(max(lambda1, lambda2)));
         | 
| 233 | 
            +
            	float2 point_image = { ndc2Pix(p_proj.x, W), ndc2Pix(p_proj.y, H) };
         | 
| 234 | 
            +
            	uint2 rect_min, rect_max;
         | 
| 235 | 
            +
            	getRect(point_image, my_radius, rect_min, rect_max, grid);
         | 
| 236 | 
            +
            	if ((rect_max.x - rect_min.x) * (rect_max.y - rect_min.y) == 0)
         | 
| 237 | 
            +
            		return;
         | 
| 238 | 
            +
             | 
| 239 | 
            +
            	// If colors have been precomputed, use them, otherwise convert
         | 
| 240 | 
            +
            	// spherical harmonics coefficients to RGB color.
         | 
| 241 | 
            +
            	if (colors_precomp == nullptr)
         | 
| 242 | 
            +
            	{
         | 
| 243 | 
            +
            		glm::vec3 result = computeColorFromSH(idx, D, M, (glm::vec3*)orig_points, *cam_pos, shs, clamped);
         | 
| 244 | 
            +
            		rgb[idx * C + 0] = result.x;
         | 
| 245 | 
            +
            		rgb[idx * C + 1] = result.y;
         | 
| 246 | 
            +
            		rgb[idx * C + 2] = result.z;
         | 
| 247 | 
            +
            	}
         | 
| 248 | 
            +
             | 
| 249 | 
            +
            	// Store some useful helper data for the next steps.
         | 
| 250 | 
            +
            	depths[idx] = p_view.z;
         | 
| 251 | 
            +
            	radii[idx] = my_radius;
         | 
| 252 | 
            +
            	points_xy_image[idx] = point_image;
         | 
| 253 | 
            +
            	// Inverse 2D covariance and opacity neatly pack into one float4
         | 
| 254 | 
            +
            	conic_opacity[idx] = { conic.x, conic.y, conic.z, opacities[idx] };
         | 
| 255 | 
            +
            	tiles_touched[idx] = (rect_max.y - rect_min.y) * (rect_max.x - rect_min.x);
         | 
| 256 | 
            +
            }
         | 
| 257 | 
            +
             | 
| 258 | 
            +
            // Main rasterization method. Collaboratively works on one tile per
         | 
| 259 | 
            +
            // block, each thread treats one pixel. Alternates between fetching 
         | 
| 260 | 
            +
            // and rasterizing data.
         | 
| 261 | 
            +
            template <uint32_t CHANNELS>
         | 
| 262 | 
            +
            __global__ void __launch_bounds__(BLOCK_X * BLOCK_Y)
         | 
| 263 | 
            +
            renderCUDA(
         | 
| 264 | 
            +
            	const uint2* __restrict__ ranges,
         | 
| 265 | 
            +
            	const uint32_t* __restrict__ point_list,
         | 
| 266 | 
            +
            	int W, int H,
         | 
| 267 | 
            +
            	const float2* __restrict__ points_xy_image,
         | 
| 268 | 
            +
            	const float* __restrict__ features,
         | 
| 269 | 
            +
            	const float4* __restrict__ conic_opacity,
         | 
| 270 | 
            +
            	float* __restrict__ final_T,
         | 
| 271 | 
            +
            	uint32_t* __restrict__ n_contrib,
         | 
| 272 | 
            +
            	const float* __restrict__ bg_color,
         | 
| 273 | 
            +
            	float* __restrict__ out_color)
         | 
| 274 | 
            +
            {
         | 
| 275 | 
            +
            	// Identify current tile and associated min/max pixel range.
         | 
| 276 | 
            +
            	auto block = cg::this_thread_block();
         | 
| 277 | 
            +
            	uint32_t horizontal_blocks = (W + BLOCK_X - 1) / BLOCK_X;
         | 
| 278 | 
            +
            	uint2 pix_min = { block.group_index().x * BLOCK_X, block.group_index().y * BLOCK_Y };
         | 
| 279 | 
            +
            	uint2 pix_max = { min(pix_min.x + BLOCK_X, W), min(pix_min.y + BLOCK_Y , H) };
         | 
| 280 | 
            +
            	uint2 pix = { pix_min.x + block.thread_index().x, pix_min.y + block.thread_index().y };
         | 
| 281 | 
            +
            	uint32_t pix_id = W * pix.y + pix.x;
         | 
| 282 | 
            +
            	float2 pixf = { (float)pix.x, (float)pix.y };
         | 
| 283 | 
            +
             | 
| 284 | 
            +
            	// Check if this thread is associated with a valid pixel or outside.
         | 
| 285 | 
            +
            	bool inside = pix.x < W&& pix.y < H;
         | 
| 286 | 
            +
            	// Done threads can help with fetching, but don't rasterize
         | 
| 287 | 
            +
            	bool done = !inside;
         | 
| 288 | 
            +
             | 
| 289 | 
            +
            	// Load start/end range of IDs to process in bit sorted list.
         | 
| 290 | 
            +
            	uint2 range = ranges[block.group_index().y * horizontal_blocks + block.group_index().x];
         | 
| 291 | 
            +
            	const int rounds = ((range.y - range.x + BLOCK_SIZE - 1) / BLOCK_SIZE);
         | 
| 292 | 
            +
            	int toDo = range.y - range.x;
         | 
| 293 | 
            +
             | 
| 294 | 
            +
            	// Allocate storage for batches of collectively fetched data.
         | 
| 295 | 
            +
            	__shared__ int collected_id[BLOCK_SIZE];
         | 
| 296 | 
            +
            	__shared__ float2 collected_xy[BLOCK_SIZE];
         | 
| 297 | 
            +
            	__shared__ float4 collected_conic_opacity[BLOCK_SIZE];
         | 
| 298 | 
            +
             | 
| 299 | 
            +
            	// Initialize helper variables
         | 
| 300 | 
            +
            	float T = 1.0f;
         | 
| 301 | 
            +
            	uint32_t contributor = 0;
         | 
| 302 | 
            +
            	uint32_t last_contributor = 0;
         | 
| 303 | 
            +
            	float C[CHANNELS] = { 0 };
         | 
| 304 | 
            +
             | 
| 305 | 
            +
            	// Iterate over batches until all done or range is complete
         | 
| 306 | 
            +
            	for (int i = 0; i < rounds; i++, toDo -= BLOCK_SIZE)
         | 
| 307 | 
            +
            	{
         | 
| 308 | 
            +
            		// End if entire block votes that it is done rasterizing
         | 
| 309 | 
            +
            		int num_done = __syncthreads_count(done);
         | 
| 310 | 
            +
            		if (num_done == BLOCK_SIZE)
         | 
| 311 | 
            +
            			break;
         | 
| 312 | 
            +
             | 
| 313 | 
            +
            		// Collectively fetch per-Gaussian data from global to shared
         | 
| 314 | 
            +
            		int progress = i * BLOCK_SIZE + block.thread_rank();
         | 
| 315 | 
            +
            		if (range.x + progress < range.y)
         | 
| 316 | 
            +
            		{
         | 
| 317 | 
            +
            			int coll_id = point_list[range.x + progress];
         | 
| 318 | 
            +
            			collected_id[block.thread_rank()] = coll_id;
         | 
| 319 | 
            +
            			collected_xy[block.thread_rank()] = points_xy_image[coll_id];
         | 
| 320 | 
            +
            			collected_conic_opacity[block.thread_rank()] = conic_opacity[coll_id];
         | 
| 321 | 
            +
            		}
         | 
| 322 | 
            +
            		block.sync();
         | 
| 323 | 
            +
             | 
| 324 | 
            +
            		// Iterate over current batch
         | 
| 325 | 
            +
            		for (int j = 0; !done && j < min(BLOCK_SIZE, toDo); j++)
         | 
| 326 | 
            +
            		{
         | 
| 327 | 
            +
            			// Keep track of current position in range
         | 
| 328 | 
            +
            			contributor++;
         | 
| 329 | 
            +
             | 
| 330 | 
            +
            			// Resample using conic matrix (cf. "Surface 
         | 
| 331 | 
            +
            			// Splatting" by Zwicker et al., 2001)
         | 
| 332 | 
            +
            			float2 xy = collected_xy[j];
         | 
| 333 | 
            +
            			float2 d = { xy.x - pixf.x, xy.y - pixf.y };
         | 
| 334 | 
            +
            			float4 con_o = collected_conic_opacity[j];
         | 
| 335 | 
            +
            			float power = -0.5f * (con_o.x * d.x * d.x + con_o.z * d.y * d.y) - con_o.y * d.x * d.y;
         | 
| 336 | 
            +
            			if (power > 0.0f)
         | 
| 337 | 
            +
            				continue;
         | 
| 338 | 
            +
             | 
| 339 | 
            +
            			// Eq. (2) from 3D Gaussian splatting paper.
         | 
| 340 | 
            +
            			// Obtain alpha by multiplying with Gaussian opacity
         | 
| 341 | 
            +
            			// and its exponential falloff from mean.
         | 
| 342 | 
            +
            			// Avoid numerical instabilities (see paper appendix). 
         | 
| 343 | 
            +
            			float alpha = min(0.99f, con_o.w * exp(power));
         | 
| 344 | 
            +
            			if (alpha < 1.0f / 255.0f)
         | 
| 345 | 
            +
            				continue;
         | 
| 346 | 
            +
            			float test_T = T * (1 - alpha);
         | 
| 347 | 
            +
            			if (test_T < 0.0001f)
         | 
| 348 | 
            +
            			{
         | 
| 349 | 
            +
            				done = true;
         | 
| 350 | 
            +
            				continue;
         | 
| 351 | 
            +
            			}
         | 
| 352 | 
            +
             | 
| 353 | 
            +
            			// Eq. (3) from 3D Gaussian splatting paper.
         | 
| 354 | 
            +
            			for (int ch = 0; ch < CHANNELS; ch++)
         | 
| 355 | 
            +
            				C[ch] += features[collected_id[j] * CHANNELS + ch] * alpha * T;
         | 
| 356 | 
            +
             | 
| 357 | 
            +
            			T = test_T;
         | 
| 358 | 
            +
             | 
| 359 | 
            +
            			// Keep track of last range entry to update this
         | 
| 360 | 
            +
            			// pixel.
         | 
| 361 | 
            +
            			last_contributor = contributor;
         | 
| 362 | 
            +
            		}
         | 
| 363 | 
            +
            	}
         | 
| 364 | 
            +
             | 
| 365 | 
            +
            	// All threads that treat valid pixel write out their final
         | 
| 366 | 
            +
            	// rendering data to the frame and auxiliary buffers.
         | 
| 367 | 
            +
            	if (inside)
         | 
| 368 | 
            +
            	{
         | 
| 369 | 
            +
            		final_T[pix_id] = T;
         | 
| 370 | 
            +
            		n_contrib[pix_id] = last_contributor;
         | 
| 371 | 
            +
            		for (int ch = 0; ch < CHANNELS; ch++)
         | 
| 372 | 
            +
            			out_color[ch * H * W + pix_id] = C[ch] + T * bg_color[ch];
         | 
| 373 | 
            +
            	}
         | 
| 374 | 
            +
            }
         | 
| 375 | 
            +
             | 
| 376 | 
            +
            void FORWARD::render(
         | 
| 377 | 
            +
            	const dim3 grid, dim3 block,
         | 
| 378 | 
            +
            	const uint2* ranges,
         | 
| 379 | 
            +
            	const uint32_t* point_list,
         | 
| 380 | 
            +
            	int W, int H,
         | 
| 381 | 
            +
            	const float2* means2D,
         | 
| 382 | 
            +
            	const float* colors,
         | 
| 383 | 
            +
            	const float4* conic_opacity,
         | 
| 384 | 
            +
            	float* final_T,
         | 
| 385 | 
            +
            	uint32_t* n_contrib,
         | 
| 386 | 
            +
            	const float* bg_color,
         | 
| 387 | 
            +
            	float* out_color)
         | 
| 388 | 
            +
            {
         | 
| 389 | 
            +
            	renderCUDA<NUM_CHANNELS> << <grid, block >> > (
         | 
| 390 | 
            +
            		ranges,
         | 
| 391 | 
            +
            		point_list,
         | 
| 392 | 
            +
            		W, H,
         | 
| 393 | 
            +
            		means2D,
         | 
| 394 | 
            +
            		colors,
         | 
| 395 | 
            +
            		conic_opacity,
         | 
| 396 | 
            +
            		final_T,
         | 
| 397 | 
            +
            		n_contrib,
         | 
| 398 | 
            +
            		bg_color,
         | 
| 399 | 
            +
            		out_color);
         | 
| 400 | 
            +
            }
         | 
| 401 | 
            +
             | 
| 402 | 
            +
            void FORWARD::preprocess(int P, int D, int M,
         | 
| 403 | 
            +
            	const float* means3D,
         | 
| 404 | 
            +
            	const glm::vec3* scales,
         | 
| 405 | 
            +
            	const float scale_modifier,
         | 
| 406 | 
            +
            	const glm::vec4* rotations,
         | 
| 407 | 
            +
            	const float* opacities,
         | 
| 408 | 
            +
            	const float* shs,
         | 
| 409 | 
            +
            	bool* clamped,
         | 
| 410 | 
            +
            	const float* cov3D_precomp,
         | 
| 411 | 
            +
            	const float* colors_precomp,
         | 
| 412 | 
            +
            	const float* viewmatrix,
         | 
| 413 | 
            +
            	const float* projmatrix,
         | 
| 414 | 
            +
            	const glm::vec3* cam_pos,
         | 
| 415 | 
            +
            	const int W, int H,
         | 
| 416 | 
            +
            	const float focal_x, float focal_y,
         | 
| 417 | 
            +
            	const float tan_fovx, float tan_fovy,
         | 
| 418 | 
            +
            	int* radii,
         | 
| 419 | 
            +
            	float2* means2D,
         | 
| 420 | 
            +
            	float* depths,
         | 
| 421 | 
            +
            	float* cov3Ds,
         | 
| 422 | 
            +
            	float* rgb,
         | 
| 423 | 
            +
            	float4* conic_opacity,
         | 
| 424 | 
            +
            	const dim3 grid,
         | 
| 425 | 
            +
            	uint32_t* tiles_touched,
         | 
| 426 | 
            +
            	bool prefiltered)
         | 
| 427 | 
            +
            {
         | 
| 428 | 
            +
            	preprocessCUDA<NUM_CHANNELS> << <(P + 255) / 256, 256 >> > (
         | 
| 429 | 
            +
            		P, D, M,
         | 
| 430 | 
            +
            		means3D,
         | 
| 431 | 
            +
            		scales,
         | 
| 432 | 
            +
            		scale_modifier,
         | 
| 433 | 
            +
            		rotations,
         | 
| 434 | 
            +
            		opacities,
         | 
| 435 | 
            +
            		shs,
         | 
| 436 | 
            +
            		clamped,
         | 
| 437 | 
            +
            		cov3D_precomp,
         | 
| 438 | 
            +
            		colors_precomp,
         | 
| 439 | 
            +
            		viewmatrix, 
         | 
| 440 | 
            +
            		projmatrix,
         | 
| 441 | 
            +
            		cam_pos,
         | 
| 442 | 
            +
            		W, H,
         | 
| 443 | 
            +
            		tan_fovx, tan_fovy,
         | 
| 444 | 
            +
            		focal_x, focal_y,
         | 
| 445 | 
            +
            		radii,
         | 
| 446 | 
            +
            		means2D,
         | 
| 447 | 
            +
            		depths,
         | 
| 448 | 
            +
            		cov3Ds,
         | 
| 449 | 
            +
            		rgb,
         | 
| 450 | 
            +
            		conic_opacity,
         | 
| 451 | 
            +
            		grid,
         | 
| 452 | 
            +
            		tiles_touched,
         | 
| 453 | 
            +
            		prefiltered
         | 
| 454 | 
            +
            		);
         | 
| 455 | 
            +
            }
         | 
    	
        submodules/diff-gaussian-rasterization/cuda_rasterizer/forward.h
    ADDED
    
    | @@ -0,0 +1,66 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #ifndef CUDA_RASTERIZER_FORWARD_H_INCLUDED
         | 
| 13 | 
            +
            #define CUDA_RASTERIZER_FORWARD_H_INCLUDED
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            #include <cuda.h>
         | 
| 16 | 
            +
            #include "cuda_runtime.h"
         | 
| 17 | 
            +
            #include "device_launch_parameters.h"
         | 
| 18 | 
            +
            #define GLM_FORCE_CUDA
         | 
| 19 | 
            +
            #include <glm/glm.hpp>
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            namespace FORWARD
         | 
| 22 | 
            +
            {
         | 
| 23 | 
            +
            	// Perform initial steps for each Gaussian prior to rasterization.
         | 
| 24 | 
            +
            	void preprocess(int P, int D, int M,
         | 
| 25 | 
            +
            		const float* orig_points,
         | 
| 26 | 
            +
            		const glm::vec3* scales,
         | 
| 27 | 
            +
            		const float scale_modifier,
         | 
| 28 | 
            +
            		const glm::vec4* rotations,
         | 
| 29 | 
            +
            		const float* opacities,
         | 
| 30 | 
            +
            		const float* shs,
         | 
| 31 | 
            +
            		bool* clamped,
         | 
| 32 | 
            +
            		const float* cov3D_precomp,
         | 
| 33 | 
            +
            		const float* colors_precomp,
         | 
| 34 | 
            +
            		const float* viewmatrix,
         | 
| 35 | 
            +
            		const float* projmatrix,
         | 
| 36 | 
            +
            		const glm::vec3* cam_pos,
         | 
| 37 | 
            +
            		const int W, int H,
         | 
| 38 | 
            +
            		const float focal_x, float focal_y,
         | 
| 39 | 
            +
            		const float tan_fovx, float tan_fovy,
         | 
| 40 | 
            +
            		int* radii,
         | 
| 41 | 
            +
            		float2* points_xy_image,
         | 
| 42 | 
            +
            		float* depths,
         | 
| 43 | 
            +
            		float* cov3Ds,
         | 
| 44 | 
            +
            		float* colors,
         | 
| 45 | 
            +
            		float4* conic_opacity,
         | 
| 46 | 
            +
            		const dim3 grid,
         | 
| 47 | 
            +
            		uint32_t* tiles_touched,
         | 
| 48 | 
            +
            		bool prefiltered);
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            	// Main rasterization method.
         | 
| 51 | 
            +
            	void render(
         | 
| 52 | 
            +
            		const dim3 grid, dim3 block,
         | 
| 53 | 
            +
            		const uint2* ranges,
         | 
| 54 | 
            +
            		const uint32_t* point_list,
         | 
| 55 | 
            +
            		int W, int H,
         | 
| 56 | 
            +
            		const float2* points_xy_image,
         | 
| 57 | 
            +
            		const float* features,
         | 
| 58 | 
            +
            		const float4* conic_opacity,
         | 
| 59 | 
            +
            		float* final_T,
         | 
| 60 | 
            +
            		uint32_t* n_contrib,
         | 
| 61 | 
            +
            		const float* bg_color,
         | 
| 62 | 
            +
            		float* out_color);
         | 
| 63 | 
            +
            }
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            #endif
         | 
    	
        submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer.h
    ADDED
    
    | @@ -0,0 +1,88 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #ifndef CUDA_RASTERIZER_H_INCLUDED
         | 
| 13 | 
            +
            #define CUDA_RASTERIZER_H_INCLUDED
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            #include <vector>
         | 
| 16 | 
            +
            #include <functional>
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            namespace CudaRasterizer
         | 
| 19 | 
            +
            {
         | 
| 20 | 
            +
            	class Rasterizer
         | 
| 21 | 
            +
            	{
         | 
| 22 | 
            +
            	public:
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            		static void markVisible(
         | 
| 25 | 
            +
            			int P,
         | 
| 26 | 
            +
            			float* means3D,
         | 
| 27 | 
            +
            			float* viewmatrix,
         | 
| 28 | 
            +
            			float* projmatrix,
         | 
| 29 | 
            +
            			bool* present);
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            		static int forward(
         | 
| 32 | 
            +
            			std::function<char* (size_t)> geometryBuffer,
         | 
| 33 | 
            +
            			std::function<char* (size_t)> binningBuffer,
         | 
| 34 | 
            +
            			std::function<char* (size_t)> imageBuffer,
         | 
| 35 | 
            +
            			const int P, int D, int M,
         | 
| 36 | 
            +
            			const float* background,
         | 
| 37 | 
            +
            			const int width, int height,
         | 
| 38 | 
            +
            			const float* means3D,
         | 
| 39 | 
            +
            			const float* shs,
         | 
| 40 | 
            +
            			const float* colors_precomp,
         | 
| 41 | 
            +
            			const float* opacities,
         | 
| 42 | 
            +
            			const float* scales,
         | 
| 43 | 
            +
            			const float scale_modifier,
         | 
| 44 | 
            +
            			const float* rotations,
         | 
| 45 | 
            +
            			const float* cov3D_precomp,
         | 
| 46 | 
            +
            			const float* viewmatrix,
         | 
| 47 | 
            +
            			const float* projmatrix,
         | 
| 48 | 
            +
            			const float* cam_pos,
         | 
| 49 | 
            +
            			const float tan_fovx, float tan_fovy,
         | 
| 50 | 
            +
            			const bool prefiltered,
         | 
| 51 | 
            +
            			float* out_color,
         | 
| 52 | 
            +
            			int* radii = nullptr,
         | 
| 53 | 
            +
            			bool debug = false);
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            		static void backward(
         | 
| 56 | 
            +
            			const int P, int D, int M, int R,
         | 
| 57 | 
            +
            			const float* background,
         | 
| 58 | 
            +
            			const int width, int height,
         | 
| 59 | 
            +
            			const float* means3D,
         | 
| 60 | 
            +
            			const float* shs,
         | 
| 61 | 
            +
            			const float* colors_precomp,
         | 
| 62 | 
            +
            			const float* scales,
         | 
| 63 | 
            +
            			const float scale_modifier,
         | 
| 64 | 
            +
            			const float* rotations,
         | 
| 65 | 
            +
            			const float* cov3D_precomp,
         | 
| 66 | 
            +
            			const float* viewmatrix,
         | 
| 67 | 
            +
            			const float* projmatrix,
         | 
| 68 | 
            +
            			const float* campos,
         | 
| 69 | 
            +
            			const float tan_fovx, float tan_fovy,
         | 
| 70 | 
            +
            			const int* radii,
         | 
| 71 | 
            +
            			char* geom_buffer,
         | 
| 72 | 
            +
            			char* binning_buffer,
         | 
| 73 | 
            +
            			char* image_buffer,
         | 
| 74 | 
            +
            			const float* dL_dpix,
         | 
| 75 | 
            +
            			float* dL_dmean2D,
         | 
| 76 | 
            +
            			float* dL_dconic,
         | 
| 77 | 
            +
            			float* dL_dopacity,
         | 
| 78 | 
            +
            			float* dL_dcolor,
         | 
| 79 | 
            +
            			float* dL_dmean3D,
         | 
| 80 | 
            +
            			float* dL_dcov3D,
         | 
| 81 | 
            +
            			float* dL_dsh,
         | 
| 82 | 
            +
            			float* dL_dscale,
         | 
| 83 | 
            +
            			float* dL_drot,
         | 
| 84 | 
            +
            			bool debug);
         | 
| 85 | 
            +
            	};
         | 
| 86 | 
            +
            };
         | 
| 87 | 
            +
             | 
| 88 | 
            +
            #endif
         | 
    	
        submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.cu
    ADDED
    
    | @@ -0,0 +1,434 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #include "rasterizer_impl.h"
         | 
| 13 | 
            +
            #include <iostream>
         | 
| 14 | 
            +
            #include <fstream>
         | 
| 15 | 
            +
            #include <algorithm>
         | 
| 16 | 
            +
            #include <numeric>
         | 
| 17 | 
            +
            #include <cuda.h>
         | 
| 18 | 
            +
            #include "cuda_runtime.h"
         | 
| 19 | 
            +
            #include "device_launch_parameters.h"
         | 
| 20 | 
            +
            #include <cub/cub.cuh>
         | 
| 21 | 
            +
            #include <cub/device/device_radix_sort.cuh>
         | 
| 22 | 
            +
            #define GLM_FORCE_CUDA
         | 
| 23 | 
            +
            #include <glm/glm.hpp>
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            #include <cooperative_groups.h>
         | 
| 26 | 
            +
            #include <cooperative_groups/reduce.h>
         | 
| 27 | 
            +
            namespace cg = cooperative_groups;
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            #include "auxiliary.h"
         | 
| 30 | 
            +
            #include "forward.h"
         | 
| 31 | 
            +
            #include "backward.h"
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            // Helper function to find the next-highest bit of the MSB
         | 
| 34 | 
            +
            // on the CPU.
         | 
| 35 | 
            +
            uint32_t getHigherMsb(uint32_t n)
         | 
| 36 | 
            +
            {
         | 
| 37 | 
            +
            	uint32_t msb = sizeof(n) * 4;
         | 
| 38 | 
            +
            	uint32_t step = msb;
         | 
| 39 | 
            +
            	while (step > 1)
         | 
| 40 | 
            +
            	{
         | 
| 41 | 
            +
            		step /= 2;
         | 
| 42 | 
            +
            		if (n >> msb)
         | 
| 43 | 
            +
            			msb += step;
         | 
| 44 | 
            +
            		else
         | 
| 45 | 
            +
            			msb -= step;
         | 
| 46 | 
            +
            	}
         | 
| 47 | 
            +
            	if (n >> msb)
         | 
| 48 | 
            +
            		msb++;
         | 
| 49 | 
            +
            	return msb;
         | 
| 50 | 
            +
            }
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            // Wrapper method to call auxiliary coarse frustum containment test.
         | 
| 53 | 
            +
            // Mark all Gaussians that pass it.
         | 
| 54 | 
            +
            __global__ void checkFrustum(int P,
         | 
| 55 | 
            +
            	const float* orig_points,
         | 
| 56 | 
            +
            	const float* viewmatrix,
         | 
| 57 | 
            +
            	const float* projmatrix,
         | 
| 58 | 
            +
            	bool* present)
         | 
| 59 | 
            +
            {
         | 
| 60 | 
            +
            	auto idx = cg::this_grid().thread_rank();
         | 
| 61 | 
            +
            	if (idx >= P)
         | 
| 62 | 
            +
            		return;
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            	float3 p_view;
         | 
| 65 | 
            +
            	present[idx] = in_frustum(idx, orig_points, viewmatrix, projmatrix, false, p_view);
         | 
| 66 | 
            +
            }
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            // Generates one key/value pair for all Gaussian / tile overlaps. 
         | 
| 69 | 
            +
            // Run once per Gaussian (1:N mapping).
         | 
| 70 | 
            +
            __global__ void duplicateWithKeys(
         | 
| 71 | 
            +
            	int P,
         | 
| 72 | 
            +
            	const float2* points_xy,
         | 
| 73 | 
            +
            	const float* depths,
         | 
| 74 | 
            +
            	const uint32_t* offsets,
         | 
| 75 | 
            +
            	uint64_t* gaussian_keys_unsorted,
         | 
| 76 | 
            +
            	uint32_t* gaussian_values_unsorted,
         | 
| 77 | 
            +
            	int* radii,
         | 
| 78 | 
            +
            	dim3 grid)
         | 
| 79 | 
            +
            {
         | 
| 80 | 
            +
            	auto idx = cg::this_grid().thread_rank();
         | 
| 81 | 
            +
            	if (idx >= P)
         | 
| 82 | 
            +
            		return;
         | 
| 83 | 
            +
             | 
| 84 | 
            +
            	// Generate no key/value pair for invisible Gaussians
         | 
| 85 | 
            +
            	if (radii[idx] > 0)
         | 
| 86 | 
            +
            	{
         | 
| 87 | 
            +
            		// Find this Gaussian's offset in buffer for writing keys/values.
         | 
| 88 | 
            +
            		uint32_t off = (idx == 0) ? 0 : offsets[idx - 1];
         | 
| 89 | 
            +
            		uint2 rect_min, rect_max;
         | 
| 90 | 
            +
             | 
| 91 | 
            +
            		getRect(points_xy[idx], radii[idx], rect_min, rect_max, grid);
         | 
| 92 | 
            +
             | 
| 93 | 
            +
            		// For each tile that the bounding rect overlaps, emit a 
         | 
| 94 | 
            +
            		// key/value pair. The key is |  tile ID  |      depth      |,
         | 
| 95 | 
            +
            		// and the value is the ID of the Gaussian. Sorting the values 
         | 
| 96 | 
            +
            		// with this key yields Gaussian IDs in a list, such that they
         | 
| 97 | 
            +
            		// are first sorted by tile and then by depth. 
         | 
| 98 | 
            +
            		for (int y = rect_min.y; y < rect_max.y; y++)
         | 
| 99 | 
            +
            		{
         | 
| 100 | 
            +
            			for (int x = rect_min.x; x < rect_max.x; x++)
         | 
| 101 | 
            +
            			{
         | 
| 102 | 
            +
            				uint64_t key = y * grid.x + x;
         | 
| 103 | 
            +
            				key <<= 32;
         | 
| 104 | 
            +
            				key |= *((uint32_t*)&depths[idx]);
         | 
| 105 | 
            +
            				gaussian_keys_unsorted[off] = key;
         | 
| 106 | 
            +
            				gaussian_values_unsorted[off] = idx;
         | 
| 107 | 
            +
            				off++;
         | 
| 108 | 
            +
            			}
         | 
| 109 | 
            +
            		}
         | 
| 110 | 
            +
            	}
         | 
| 111 | 
            +
            }
         | 
| 112 | 
            +
             | 
| 113 | 
            +
            // Check keys to see if it is at the start/end of one tile's range in 
         | 
| 114 | 
            +
            // the full sorted list. If yes, write start/end of this tile. 
         | 
| 115 | 
            +
            // Run once per instanced (duplicated) Gaussian ID.
         | 
| 116 | 
            +
            __global__ void identifyTileRanges(int L, uint64_t* point_list_keys, uint2* ranges)
         | 
| 117 | 
            +
            {
         | 
| 118 | 
            +
            	auto idx = cg::this_grid().thread_rank();
         | 
| 119 | 
            +
            	if (idx >= L)
         | 
| 120 | 
            +
            		return;
         | 
| 121 | 
            +
             | 
| 122 | 
            +
            	// Read tile ID from key. Update start/end of tile range if at limit.
         | 
| 123 | 
            +
            	uint64_t key = point_list_keys[idx];
         | 
| 124 | 
            +
            	uint32_t currtile = key >> 32;
         | 
| 125 | 
            +
            	if (idx == 0)
         | 
| 126 | 
            +
            		ranges[currtile].x = 0;
         | 
| 127 | 
            +
            	else
         | 
| 128 | 
            +
            	{
         | 
| 129 | 
            +
            		uint32_t prevtile = point_list_keys[idx - 1] >> 32;
         | 
| 130 | 
            +
            		if (currtile != prevtile)
         | 
| 131 | 
            +
            		{
         | 
| 132 | 
            +
            			ranges[prevtile].y = idx;
         | 
| 133 | 
            +
            			ranges[currtile].x = idx;
         | 
| 134 | 
            +
            		}
         | 
| 135 | 
            +
            	}
         | 
| 136 | 
            +
            	if (idx == L - 1)
         | 
| 137 | 
            +
            		ranges[currtile].y = L;
         | 
| 138 | 
            +
            }
         | 
| 139 | 
            +
             | 
| 140 | 
            +
            // Mark Gaussians as visible/invisible, based on view frustum testing
         | 
| 141 | 
            +
            void CudaRasterizer::Rasterizer::markVisible(
         | 
| 142 | 
            +
            	int P,
         | 
| 143 | 
            +
            	float* means3D,
         | 
| 144 | 
            +
            	float* viewmatrix,
         | 
| 145 | 
            +
            	float* projmatrix,
         | 
| 146 | 
            +
            	bool* present)
         | 
| 147 | 
            +
            {
         | 
| 148 | 
            +
            	checkFrustum << <(P + 255) / 256, 256 >> > (
         | 
| 149 | 
            +
            		P,
         | 
| 150 | 
            +
            		means3D,
         | 
| 151 | 
            +
            		viewmatrix, projmatrix,
         | 
| 152 | 
            +
            		present);
         | 
| 153 | 
            +
            }
         | 
| 154 | 
            +
             | 
| 155 | 
            +
            CudaRasterizer::GeometryState CudaRasterizer::GeometryState::fromChunk(char*& chunk, size_t P)
         | 
| 156 | 
            +
            {
         | 
| 157 | 
            +
            	GeometryState geom;
         | 
| 158 | 
            +
            	obtain(chunk, geom.depths, P, 128);
         | 
| 159 | 
            +
            	obtain(chunk, geom.clamped, P * 3, 128);
         | 
| 160 | 
            +
            	obtain(chunk, geom.internal_radii, P, 128);
         | 
| 161 | 
            +
            	obtain(chunk, geom.means2D, P, 128);
         | 
| 162 | 
            +
            	obtain(chunk, geom.cov3D, P * 6, 128);
         | 
| 163 | 
            +
            	obtain(chunk, geom.conic_opacity, P, 128);
         | 
| 164 | 
            +
            	obtain(chunk, geom.rgb, P * 3, 128);
         | 
| 165 | 
            +
            	obtain(chunk, geom.tiles_touched, P, 128);
         | 
| 166 | 
            +
            	cub::DeviceScan::InclusiveSum(nullptr, geom.scan_size, geom.tiles_touched, geom.tiles_touched, P);
         | 
| 167 | 
            +
            	obtain(chunk, geom.scanning_space, geom.scan_size, 128);
         | 
| 168 | 
            +
            	obtain(chunk, geom.point_offsets, P, 128);
         | 
| 169 | 
            +
            	return geom;
         | 
| 170 | 
            +
            }
         | 
| 171 | 
            +
             | 
| 172 | 
            +
            CudaRasterizer::ImageState CudaRasterizer::ImageState::fromChunk(char*& chunk, size_t N)
         | 
| 173 | 
            +
            {
         | 
| 174 | 
            +
            	ImageState img;
         | 
| 175 | 
            +
            	obtain(chunk, img.accum_alpha, N, 128);
         | 
| 176 | 
            +
            	obtain(chunk, img.n_contrib, N, 128);
         | 
| 177 | 
            +
            	obtain(chunk, img.ranges, N, 128);
         | 
| 178 | 
            +
            	return img;
         | 
| 179 | 
            +
            }
         | 
| 180 | 
            +
             | 
| 181 | 
            +
            CudaRasterizer::BinningState CudaRasterizer::BinningState::fromChunk(char*& chunk, size_t P)
         | 
| 182 | 
            +
            {
         | 
| 183 | 
            +
            	BinningState binning;
         | 
| 184 | 
            +
            	obtain(chunk, binning.point_list, P, 128);
         | 
| 185 | 
            +
            	obtain(chunk, binning.point_list_unsorted, P, 128);
         | 
| 186 | 
            +
            	obtain(chunk, binning.point_list_keys, P, 128);
         | 
| 187 | 
            +
            	obtain(chunk, binning.point_list_keys_unsorted, P, 128);
         | 
| 188 | 
            +
            	cub::DeviceRadixSort::SortPairs(
         | 
| 189 | 
            +
            		nullptr, binning.sorting_size,
         | 
| 190 | 
            +
            		binning.point_list_keys_unsorted, binning.point_list_keys,
         | 
| 191 | 
            +
            		binning.point_list_unsorted, binning.point_list, P);
         | 
| 192 | 
            +
            	obtain(chunk, binning.list_sorting_space, binning.sorting_size, 128);
         | 
| 193 | 
            +
            	return binning;
         | 
| 194 | 
            +
            }
         | 
| 195 | 
            +
             | 
| 196 | 
            +
            // Forward rendering procedure for differentiable rasterization
         | 
| 197 | 
            +
            // of Gaussians.
         | 
| 198 | 
            +
            int CudaRasterizer::Rasterizer::forward(
         | 
| 199 | 
            +
            	std::function<char* (size_t)> geometryBuffer,
         | 
| 200 | 
            +
            	std::function<char* (size_t)> binningBuffer,
         | 
| 201 | 
            +
            	std::function<char* (size_t)> imageBuffer,
         | 
| 202 | 
            +
            	const int P, int D, int M,
         | 
| 203 | 
            +
            	const float* background,
         | 
| 204 | 
            +
            	const int width, int height,
         | 
| 205 | 
            +
            	const float* means3D,
         | 
| 206 | 
            +
            	const float* shs,
         | 
| 207 | 
            +
            	const float* colors_precomp,
         | 
| 208 | 
            +
            	const float* opacities,
         | 
| 209 | 
            +
            	const float* scales,
         | 
| 210 | 
            +
            	const float scale_modifier,
         | 
| 211 | 
            +
            	const float* rotations,
         | 
| 212 | 
            +
            	const float* cov3D_precomp,
         | 
| 213 | 
            +
            	const float* viewmatrix,
         | 
| 214 | 
            +
            	const float* projmatrix,
         | 
| 215 | 
            +
            	const float* cam_pos,
         | 
| 216 | 
            +
            	const float tan_fovx, float tan_fovy,
         | 
| 217 | 
            +
            	const bool prefiltered,
         | 
| 218 | 
            +
            	float* out_color,
         | 
| 219 | 
            +
            	int* radii,
         | 
| 220 | 
            +
            	bool debug)
         | 
| 221 | 
            +
            {
         | 
| 222 | 
            +
            	const float focal_y = height / (2.0f * tan_fovy);
         | 
| 223 | 
            +
            	const float focal_x = width / (2.0f * tan_fovx);
         | 
| 224 | 
            +
             | 
| 225 | 
            +
            	size_t chunk_size = required<GeometryState>(P);
         | 
| 226 | 
            +
            	char* chunkptr = geometryBuffer(chunk_size);
         | 
| 227 | 
            +
            	GeometryState geomState = GeometryState::fromChunk(chunkptr, P);
         | 
| 228 | 
            +
             | 
| 229 | 
            +
            	if (radii == nullptr)
         | 
| 230 | 
            +
            	{
         | 
| 231 | 
            +
            		radii = geomState.internal_radii;
         | 
| 232 | 
            +
            	}
         | 
| 233 | 
            +
             | 
| 234 | 
            +
            	dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
         | 
| 235 | 
            +
            	dim3 block(BLOCK_X, BLOCK_Y, 1);
         | 
| 236 | 
            +
             | 
| 237 | 
            +
            	// Dynamically resize image-based auxiliary buffers during training
         | 
| 238 | 
            +
            	size_t img_chunk_size = required<ImageState>(width * height);
         | 
| 239 | 
            +
            	char* img_chunkptr = imageBuffer(img_chunk_size);
         | 
| 240 | 
            +
            	ImageState imgState = ImageState::fromChunk(img_chunkptr, width * height);
         | 
| 241 | 
            +
             | 
| 242 | 
            +
            	if (NUM_CHANNELS != 3 && colors_precomp == nullptr)
         | 
| 243 | 
            +
            	{
         | 
| 244 | 
            +
            		throw std::runtime_error("For non-RGB, provide precomputed Gaussian colors!");
         | 
| 245 | 
            +
            	}
         | 
| 246 | 
            +
             | 
| 247 | 
            +
            	// Run preprocessing per-Gaussian (transformation, bounding, conversion of SHs to RGB)
         | 
| 248 | 
            +
            	CHECK_CUDA(FORWARD::preprocess(
         | 
| 249 | 
            +
            		P, D, M,
         | 
| 250 | 
            +
            		means3D,
         | 
| 251 | 
            +
            		(glm::vec3*)scales,
         | 
| 252 | 
            +
            		scale_modifier,
         | 
| 253 | 
            +
            		(glm::vec4*)rotations,
         | 
| 254 | 
            +
            		opacities,
         | 
| 255 | 
            +
            		shs,
         | 
| 256 | 
            +
            		geomState.clamped,
         | 
| 257 | 
            +
            		cov3D_precomp,
         | 
| 258 | 
            +
            		colors_precomp,
         | 
| 259 | 
            +
            		viewmatrix, projmatrix,
         | 
| 260 | 
            +
            		(glm::vec3*)cam_pos,
         | 
| 261 | 
            +
            		width, height,
         | 
| 262 | 
            +
            		focal_x, focal_y,
         | 
| 263 | 
            +
            		tan_fovx, tan_fovy,
         | 
| 264 | 
            +
            		radii,
         | 
| 265 | 
            +
            		geomState.means2D,
         | 
| 266 | 
            +
            		geomState.depths,
         | 
| 267 | 
            +
            		geomState.cov3D,
         | 
| 268 | 
            +
            		geomState.rgb,
         | 
| 269 | 
            +
            		geomState.conic_opacity,
         | 
| 270 | 
            +
            		tile_grid,
         | 
| 271 | 
            +
            		geomState.tiles_touched,
         | 
| 272 | 
            +
            		prefiltered
         | 
| 273 | 
            +
            	), debug)
         | 
| 274 | 
            +
             | 
| 275 | 
            +
            	// Compute prefix sum over full list of touched tile counts by Gaussians
         | 
| 276 | 
            +
            	// E.g., [2, 3, 0, 2, 1] -> [2, 5, 5, 7, 8]
         | 
| 277 | 
            +
            	CHECK_CUDA(cub::DeviceScan::InclusiveSum(geomState.scanning_space, geomState.scan_size, geomState.tiles_touched, geomState.point_offsets, P), debug)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
            	// Retrieve total number of Gaussian instances to launch and resize aux buffers
         | 
| 280 | 
            +
            	int num_rendered;
         | 
| 281 | 
            +
            	CHECK_CUDA(cudaMemcpy(&num_rendered, geomState.point_offsets + P - 1, sizeof(int), cudaMemcpyDeviceToHost), debug);
         | 
| 282 | 
            +
             | 
| 283 | 
            +
            	size_t binning_chunk_size = required<BinningState>(num_rendered);
         | 
| 284 | 
            +
            	char* binning_chunkptr = binningBuffer(binning_chunk_size);
         | 
| 285 | 
            +
            	BinningState binningState = BinningState::fromChunk(binning_chunkptr, num_rendered);
         | 
| 286 | 
            +
             | 
| 287 | 
            +
            	// For each instance to be rendered, produce adequate [ tile | depth ] key 
         | 
| 288 | 
            +
            	// and corresponding dublicated Gaussian indices to be sorted
         | 
| 289 | 
            +
            	duplicateWithKeys << <(P + 255) / 256, 256 >> > (
         | 
| 290 | 
            +
            		P,
         | 
| 291 | 
            +
            		geomState.means2D,
         | 
| 292 | 
            +
            		geomState.depths,
         | 
| 293 | 
            +
            		geomState.point_offsets,
         | 
| 294 | 
            +
            		binningState.point_list_keys_unsorted,
         | 
| 295 | 
            +
            		binningState.point_list_unsorted,
         | 
| 296 | 
            +
            		radii,
         | 
| 297 | 
            +
            		tile_grid)
         | 
| 298 | 
            +
            	CHECK_CUDA(, debug)
         | 
| 299 | 
            +
             | 
| 300 | 
            +
            	int bit = getHigherMsb(tile_grid.x * tile_grid.y);
         | 
| 301 | 
            +
             | 
| 302 | 
            +
            	// Sort complete list of (duplicated) Gaussian indices by keys
         | 
| 303 | 
            +
            	CHECK_CUDA(cub::DeviceRadixSort::SortPairs(
         | 
| 304 | 
            +
            		binningState.list_sorting_space,
         | 
| 305 | 
            +
            		binningState.sorting_size,
         | 
| 306 | 
            +
            		binningState.point_list_keys_unsorted, binningState.point_list_keys,
         | 
| 307 | 
            +
            		binningState.point_list_unsorted, binningState.point_list,
         | 
| 308 | 
            +
            		num_rendered, 0, 32 + bit), debug)
         | 
| 309 | 
            +
             | 
| 310 | 
            +
            	CHECK_CUDA(cudaMemset(imgState.ranges, 0, tile_grid.x * tile_grid.y * sizeof(uint2)), debug);
         | 
| 311 | 
            +
             | 
| 312 | 
            +
            	// Identify start and end of per-tile workloads in sorted list
         | 
| 313 | 
            +
            	if (num_rendered > 0)
         | 
| 314 | 
            +
            		identifyTileRanges << <(num_rendered + 255) / 256, 256 >> > (
         | 
| 315 | 
            +
            			num_rendered,
         | 
| 316 | 
            +
            			binningState.point_list_keys,
         | 
| 317 | 
            +
            			imgState.ranges);
         | 
| 318 | 
            +
            	CHECK_CUDA(, debug)
         | 
| 319 | 
            +
             | 
| 320 | 
            +
            	// Let each tile blend its range of Gaussians independently in parallel
         | 
| 321 | 
            +
            	const float* feature_ptr = colors_precomp != nullptr ? colors_precomp : geomState.rgb;
         | 
| 322 | 
            +
            	CHECK_CUDA(FORWARD::render(
         | 
| 323 | 
            +
            		tile_grid, block,
         | 
| 324 | 
            +
            		imgState.ranges,
         | 
| 325 | 
            +
            		binningState.point_list,
         | 
| 326 | 
            +
            		width, height,
         | 
| 327 | 
            +
            		geomState.means2D,
         | 
| 328 | 
            +
            		feature_ptr,
         | 
| 329 | 
            +
            		geomState.conic_opacity,
         | 
| 330 | 
            +
            		imgState.accum_alpha,
         | 
| 331 | 
            +
            		imgState.n_contrib,
         | 
| 332 | 
            +
            		background,
         | 
| 333 | 
            +
            		out_color), debug)
         | 
| 334 | 
            +
             | 
| 335 | 
            +
            	return num_rendered;
         | 
| 336 | 
            +
            }
         | 
| 337 | 
            +
             | 
| 338 | 
            +
            // Produce necessary gradients for optimization, corresponding
         | 
| 339 | 
            +
            // to forward render pass
         | 
| 340 | 
            +
            void CudaRasterizer::Rasterizer::backward(
         | 
| 341 | 
            +
            	const int P, int D, int M, int R,
         | 
| 342 | 
            +
            	const float* background,
         | 
| 343 | 
            +
            	const int width, int height,
         | 
| 344 | 
            +
            	const float* means3D,
         | 
| 345 | 
            +
            	const float* shs,
         | 
| 346 | 
            +
            	const float* colors_precomp,
         | 
| 347 | 
            +
            	const float* scales,
         | 
| 348 | 
            +
            	const float scale_modifier,
         | 
| 349 | 
            +
            	const float* rotations,
         | 
| 350 | 
            +
            	const float* cov3D_precomp,
         | 
| 351 | 
            +
            	const float* viewmatrix,
         | 
| 352 | 
            +
            	const float* projmatrix,
         | 
| 353 | 
            +
            	const float* campos,
         | 
| 354 | 
            +
            	const float tan_fovx, float tan_fovy,
         | 
| 355 | 
            +
            	const int* radii,
         | 
| 356 | 
            +
            	char* geom_buffer,
         | 
| 357 | 
            +
            	char* binning_buffer,
         | 
| 358 | 
            +
            	char* img_buffer,
         | 
| 359 | 
            +
            	const float* dL_dpix,
         | 
| 360 | 
            +
            	float* dL_dmean2D,
         | 
| 361 | 
            +
            	float* dL_dconic,
         | 
| 362 | 
            +
            	float* dL_dopacity,
         | 
| 363 | 
            +
            	float* dL_dcolor,
         | 
| 364 | 
            +
            	float* dL_dmean3D,
         | 
| 365 | 
            +
            	float* dL_dcov3D,
         | 
| 366 | 
            +
            	float* dL_dsh,
         | 
| 367 | 
            +
            	float* dL_dscale,
         | 
| 368 | 
            +
            	float* dL_drot,
         | 
| 369 | 
            +
            	bool debug)
         | 
| 370 | 
            +
            {
         | 
| 371 | 
            +
            	GeometryState geomState = GeometryState::fromChunk(geom_buffer, P);
         | 
| 372 | 
            +
            	BinningState binningState = BinningState::fromChunk(binning_buffer, R);
         | 
| 373 | 
            +
            	ImageState imgState = ImageState::fromChunk(img_buffer, width * height);
         | 
| 374 | 
            +
             | 
| 375 | 
            +
            	if (radii == nullptr)
         | 
| 376 | 
            +
            	{
         | 
| 377 | 
            +
            		radii = geomState.internal_radii;
         | 
| 378 | 
            +
            	}
         | 
| 379 | 
            +
             | 
| 380 | 
            +
            	const float focal_y = height / (2.0f * tan_fovy);
         | 
| 381 | 
            +
            	const float focal_x = width / (2.0f * tan_fovx);
         | 
| 382 | 
            +
             | 
| 383 | 
            +
            	const dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
         | 
| 384 | 
            +
            	const dim3 block(BLOCK_X, BLOCK_Y, 1);
         | 
| 385 | 
            +
             | 
| 386 | 
            +
            	// Compute loss gradients w.r.t. 2D mean position, conic matrix,
         | 
| 387 | 
            +
            	// opacity and RGB of Gaussians from per-pixel loss gradients.
         | 
| 388 | 
            +
            	// If we were given precomputed colors and not SHs, use them.
         | 
| 389 | 
            +
            	const float* color_ptr = (colors_precomp != nullptr) ? colors_precomp : geomState.rgb;
         | 
| 390 | 
            +
            	CHECK_CUDA(BACKWARD::render(
         | 
| 391 | 
            +
            		tile_grid,
         | 
| 392 | 
            +
            		block,
         | 
| 393 | 
            +
            		imgState.ranges,
         | 
| 394 | 
            +
            		binningState.point_list,
         | 
| 395 | 
            +
            		width, height,
         | 
| 396 | 
            +
            		background,
         | 
| 397 | 
            +
            		geomState.means2D,
         | 
| 398 | 
            +
            		geomState.conic_opacity,
         | 
| 399 | 
            +
            		color_ptr,
         | 
| 400 | 
            +
            		imgState.accum_alpha,
         | 
| 401 | 
            +
            		imgState.n_contrib,
         | 
| 402 | 
            +
            		dL_dpix,
         | 
| 403 | 
            +
            		(float3*)dL_dmean2D,
         | 
| 404 | 
            +
            		(float4*)dL_dconic,
         | 
| 405 | 
            +
            		dL_dopacity,
         | 
| 406 | 
            +
            		dL_dcolor), debug)
         | 
| 407 | 
            +
             | 
| 408 | 
            +
            	// Take care of the rest of preprocessing. Was the precomputed covariance
         | 
| 409 | 
            +
            	// given to us or a scales/rot pair? If precomputed, pass that. If not,
         | 
| 410 | 
            +
            	// use the one we computed ourselves.
         | 
| 411 | 
            +
            	const float* cov3D_ptr = (cov3D_precomp != nullptr) ? cov3D_precomp : geomState.cov3D;
         | 
| 412 | 
            +
            	CHECK_CUDA(BACKWARD::preprocess(P, D, M,
         | 
| 413 | 
            +
            		(float3*)means3D,
         | 
| 414 | 
            +
            		radii,
         | 
| 415 | 
            +
            		shs,
         | 
| 416 | 
            +
            		geomState.clamped,
         | 
| 417 | 
            +
            		(glm::vec3*)scales,
         | 
| 418 | 
            +
            		(glm::vec4*)rotations,
         | 
| 419 | 
            +
            		scale_modifier,
         | 
| 420 | 
            +
            		cov3D_ptr,
         | 
| 421 | 
            +
            		viewmatrix,
         | 
| 422 | 
            +
            		projmatrix,
         | 
| 423 | 
            +
            		focal_x, focal_y,
         | 
| 424 | 
            +
            		tan_fovx, tan_fovy,
         | 
| 425 | 
            +
            		(glm::vec3*)campos,
         | 
| 426 | 
            +
            		(float3*)dL_dmean2D,
         | 
| 427 | 
            +
            		dL_dconic,
         | 
| 428 | 
            +
            		(glm::vec3*)dL_dmean3D,
         | 
| 429 | 
            +
            		dL_dcolor,
         | 
| 430 | 
            +
            		dL_dcov3D,
         | 
| 431 | 
            +
            		dL_dsh,
         | 
| 432 | 
            +
            		(glm::vec3*)dL_dscale,
         | 
| 433 | 
            +
            		(glm::vec4*)dL_drot), debug)
         | 
| 434 | 
            +
            }
         | 
    	
        submodules/diff-gaussian-rasterization/cuda_rasterizer/rasterizer_impl.h
    ADDED
    
    | @@ -0,0 +1,74 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #pragma once
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            #include <iostream>
         | 
| 15 | 
            +
            #include <vector>
         | 
| 16 | 
            +
            #include "rasterizer.h"
         | 
| 17 | 
            +
            #include <cuda_runtime_api.h>
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            namespace CudaRasterizer
         | 
| 20 | 
            +
            {
         | 
| 21 | 
            +
            	template <typename T>
         | 
| 22 | 
            +
            	static void obtain(char*& chunk, T*& ptr, std::size_t count, std::size_t alignment)
         | 
| 23 | 
            +
            	{
         | 
| 24 | 
            +
            		std::size_t offset = (reinterpret_cast<std::uintptr_t>(chunk) + alignment - 1) & ~(alignment - 1);
         | 
| 25 | 
            +
            		ptr = reinterpret_cast<T*>(offset);
         | 
| 26 | 
            +
            		chunk = reinterpret_cast<char*>(ptr + count);
         | 
| 27 | 
            +
            	}
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            	struct GeometryState
         | 
| 30 | 
            +
            	{
         | 
| 31 | 
            +
            		size_t scan_size;
         | 
| 32 | 
            +
            		float* depths;
         | 
| 33 | 
            +
            		char* scanning_space;
         | 
| 34 | 
            +
            		bool* clamped;
         | 
| 35 | 
            +
            		int* internal_radii;
         | 
| 36 | 
            +
            		float2* means2D;
         | 
| 37 | 
            +
            		float* cov3D;
         | 
| 38 | 
            +
            		float4* conic_opacity;
         | 
| 39 | 
            +
            		float* rgb;
         | 
| 40 | 
            +
            		uint32_t* point_offsets;
         | 
| 41 | 
            +
            		uint32_t* tiles_touched;
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            		static GeometryState fromChunk(char*& chunk, size_t P);
         | 
| 44 | 
            +
            	};
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            	struct ImageState
         | 
| 47 | 
            +
            	{
         | 
| 48 | 
            +
            		uint2* ranges;
         | 
| 49 | 
            +
            		uint32_t* n_contrib;
         | 
| 50 | 
            +
            		float* accum_alpha;
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            		static ImageState fromChunk(char*& chunk, size_t N);
         | 
| 53 | 
            +
            	};
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            	struct BinningState
         | 
| 56 | 
            +
            	{
         | 
| 57 | 
            +
            		size_t sorting_size;
         | 
| 58 | 
            +
            		uint64_t* point_list_keys_unsorted;
         | 
| 59 | 
            +
            		uint64_t* point_list_keys;
         | 
| 60 | 
            +
            		uint32_t* point_list_unsorted;
         | 
| 61 | 
            +
            		uint32_t* point_list;
         | 
| 62 | 
            +
            		char* list_sorting_space;
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            		static BinningState fromChunk(char*& chunk, size_t P);
         | 
| 65 | 
            +
            	};
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            	template<typename T> 
         | 
| 68 | 
            +
            	size_t required(size_t P)
         | 
| 69 | 
            +
            	{
         | 
| 70 | 
            +
            		char* size = nullptr;
         | 
| 71 | 
            +
            		T::fromChunk(size, P);
         | 
| 72 | 
            +
            		return ((size_t)size) + 128;
         | 
| 73 | 
            +
            	}
         | 
| 74 | 
            +
            };
         | 
    	
        submodules/diff-gaussian-rasterization/diff_gaussian_rasterization/__init__.py
    ADDED
    
    | @@ -0,0 +1,221 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from typing import NamedTuple
         | 
| 13 | 
            +
            import torch.nn as nn
         | 
| 14 | 
            +
            import torch
         | 
| 15 | 
            +
            from . import _C
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            def cpu_deep_copy_tuple(input_tuple):
         | 
| 18 | 
            +
                copied_tensors = [item.cpu().clone() if isinstance(item, torch.Tensor) else item for item in input_tuple]
         | 
| 19 | 
            +
                return tuple(copied_tensors)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            def rasterize_gaussians(
         | 
| 22 | 
            +
                means3D,
         | 
| 23 | 
            +
                means2D,
         | 
| 24 | 
            +
                sh,
         | 
| 25 | 
            +
                colors_precomp,
         | 
| 26 | 
            +
                opacities,
         | 
| 27 | 
            +
                scales,
         | 
| 28 | 
            +
                rotations,
         | 
| 29 | 
            +
                cov3Ds_precomp,
         | 
| 30 | 
            +
                raster_settings,
         | 
| 31 | 
            +
            ):
         | 
| 32 | 
            +
                return _RasterizeGaussians.apply(
         | 
| 33 | 
            +
                    means3D,
         | 
| 34 | 
            +
                    means2D,
         | 
| 35 | 
            +
                    sh,
         | 
| 36 | 
            +
                    colors_precomp,
         | 
| 37 | 
            +
                    opacities,
         | 
| 38 | 
            +
                    scales,
         | 
| 39 | 
            +
                    rotations,
         | 
| 40 | 
            +
                    cov3Ds_precomp,
         | 
| 41 | 
            +
                    raster_settings,
         | 
| 42 | 
            +
                )
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            class _RasterizeGaussians(torch.autograd.Function):
         | 
| 45 | 
            +
                @staticmethod
         | 
| 46 | 
            +
                def forward(
         | 
| 47 | 
            +
                    ctx,
         | 
| 48 | 
            +
                    means3D,
         | 
| 49 | 
            +
                    means2D,
         | 
| 50 | 
            +
                    sh,
         | 
| 51 | 
            +
                    colors_precomp,
         | 
| 52 | 
            +
                    opacities,
         | 
| 53 | 
            +
                    scales,
         | 
| 54 | 
            +
                    rotations,
         | 
| 55 | 
            +
                    cov3Ds_precomp,
         | 
| 56 | 
            +
                    raster_settings,
         | 
| 57 | 
            +
                ):
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    # Restructure arguments the way that the C++ lib expects them
         | 
| 60 | 
            +
                    args = (
         | 
| 61 | 
            +
                        raster_settings.bg, 
         | 
| 62 | 
            +
                        means3D,
         | 
| 63 | 
            +
                        colors_precomp,
         | 
| 64 | 
            +
                        opacities,
         | 
| 65 | 
            +
                        scales,
         | 
| 66 | 
            +
                        rotations,
         | 
| 67 | 
            +
                        raster_settings.scale_modifier,
         | 
| 68 | 
            +
                        cov3Ds_precomp,
         | 
| 69 | 
            +
                        raster_settings.viewmatrix,
         | 
| 70 | 
            +
                        raster_settings.projmatrix,
         | 
| 71 | 
            +
                        raster_settings.tanfovx,
         | 
| 72 | 
            +
                        raster_settings.tanfovy,
         | 
| 73 | 
            +
                        raster_settings.image_height,
         | 
| 74 | 
            +
                        raster_settings.image_width,
         | 
| 75 | 
            +
                        sh,
         | 
| 76 | 
            +
                        raster_settings.sh_degree,
         | 
| 77 | 
            +
                        raster_settings.campos,
         | 
| 78 | 
            +
                        raster_settings.prefiltered,
         | 
| 79 | 
            +
                        raster_settings.debug
         | 
| 80 | 
            +
                    )
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    # Invoke C++/CUDA rasterizer
         | 
| 83 | 
            +
                    if raster_settings.debug:
         | 
| 84 | 
            +
                        cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
         | 
| 85 | 
            +
                        try:
         | 
| 86 | 
            +
                            num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
         | 
| 87 | 
            +
                        except Exception as ex:
         | 
| 88 | 
            +
                            torch.save(cpu_args, "snapshot_fw.dump")
         | 
| 89 | 
            +
                            print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
         | 
| 90 | 
            +
                            raise ex
         | 
| 91 | 
            +
                    else:
         | 
| 92 | 
            +
                        num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    # Keep relevant tensors for backward
         | 
| 95 | 
            +
                    ctx.raster_settings = raster_settings
         | 
| 96 | 
            +
                    ctx.num_rendered = num_rendered
         | 
| 97 | 
            +
                    ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer)
         | 
| 98 | 
            +
                    return color, radii
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                @staticmethod
         | 
| 101 | 
            +
                def backward(ctx, grad_out_color, _):
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                    # Restore necessary values from context
         | 
| 104 | 
            +
                    num_rendered = ctx.num_rendered
         | 
| 105 | 
            +
                    raster_settings = ctx.raster_settings
         | 
| 106 | 
            +
                    colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer = ctx.saved_tensors
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    # Restructure args as C++ method expects them
         | 
| 109 | 
            +
                    args = (raster_settings.bg,
         | 
| 110 | 
            +
                            means3D, 
         | 
| 111 | 
            +
                            radii, 
         | 
| 112 | 
            +
                            colors_precomp, 
         | 
| 113 | 
            +
                            scales, 
         | 
| 114 | 
            +
                            rotations, 
         | 
| 115 | 
            +
                            raster_settings.scale_modifier, 
         | 
| 116 | 
            +
                            cov3Ds_precomp, 
         | 
| 117 | 
            +
                            raster_settings.viewmatrix, 
         | 
| 118 | 
            +
                            raster_settings.projmatrix, 
         | 
| 119 | 
            +
                            raster_settings.tanfovx, 
         | 
| 120 | 
            +
                            raster_settings.tanfovy, 
         | 
| 121 | 
            +
                            grad_out_color, 
         | 
| 122 | 
            +
                            sh, 
         | 
| 123 | 
            +
                            raster_settings.sh_degree, 
         | 
| 124 | 
            +
                            raster_settings.campos,
         | 
| 125 | 
            +
                            geomBuffer,
         | 
| 126 | 
            +
                            num_rendered,
         | 
| 127 | 
            +
                            binningBuffer,
         | 
| 128 | 
            +
                            imgBuffer,
         | 
| 129 | 
            +
                            raster_settings.debug)
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    # Compute gradients for relevant tensors by invoking backward method
         | 
| 132 | 
            +
                    if raster_settings.debug:
         | 
| 133 | 
            +
                        cpu_args = cpu_deep_copy_tuple(args) # Copy them before they can be corrupted
         | 
| 134 | 
            +
                        try:
         | 
| 135 | 
            +
                            grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
         | 
| 136 | 
            +
                        except Exception as ex:
         | 
| 137 | 
            +
                            torch.save(cpu_args, "snapshot_bw.dump")
         | 
| 138 | 
            +
                            print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n")
         | 
| 139 | 
            +
                            raise ex
         | 
| 140 | 
            +
                    else:
         | 
| 141 | 
            +
                         grad_means2D, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    grads = (
         | 
| 144 | 
            +
                        grad_means3D,
         | 
| 145 | 
            +
                        grad_means2D,
         | 
| 146 | 
            +
                        grad_sh,
         | 
| 147 | 
            +
                        grad_colors_precomp,
         | 
| 148 | 
            +
                        grad_opacities,
         | 
| 149 | 
            +
                        grad_scales,
         | 
| 150 | 
            +
                        grad_rotations,
         | 
| 151 | 
            +
                        grad_cov3Ds_precomp,
         | 
| 152 | 
            +
                        None,
         | 
| 153 | 
            +
                    )
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    return grads
         | 
| 156 | 
            +
             | 
| 157 | 
            +
            class GaussianRasterizationSettings(NamedTuple):
         | 
| 158 | 
            +
                image_height: int
         | 
| 159 | 
            +
                image_width: int 
         | 
| 160 | 
            +
                tanfovx : float
         | 
| 161 | 
            +
                tanfovy : float
         | 
| 162 | 
            +
                bg : torch.Tensor
         | 
| 163 | 
            +
                scale_modifier : float
         | 
| 164 | 
            +
                viewmatrix : torch.Tensor
         | 
| 165 | 
            +
                projmatrix : torch.Tensor
         | 
| 166 | 
            +
                sh_degree : int
         | 
| 167 | 
            +
                campos : torch.Tensor
         | 
| 168 | 
            +
                prefiltered : bool
         | 
| 169 | 
            +
                debug : bool
         | 
| 170 | 
            +
             | 
| 171 | 
            +
            class GaussianRasterizer(nn.Module):
         | 
| 172 | 
            +
                def __init__(self, raster_settings):
         | 
| 173 | 
            +
                    super().__init__()
         | 
| 174 | 
            +
                    self.raster_settings = raster_settings
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                def markVisible(self, positions):
         | 
| 177 | 
            +
                    # Mark visible points (based on frustum culling for camera) with a boolean 
         | 
| 178 | 
            +
                    with torch.no_grad():
         | 
| 179 | 
            +
                        raster_settings = self.raster_settings
         | 
| 180 | 
            +
                        visible = _C.mark_visible(
         | 
| 181 | 
            +
                            positions,
         | 
| 182 | 
            +
                            raster_settings.viewmatrix,
         | 
| 183 | 
            +
                            raster_settings.projmatrix)
         | 
| 184 | 
            +
                        
         | 
| 185 | 
            +
                    return visible
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None):
         | 
| 188 | 
            +
                    
         | 
| 189 | 
            +
                    raster_settings = self.raster_settings
         | 
| 190 | 
            +
             | 
| 191 | 
            +
                    if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
         | 
| 192 | 
            +
                        raise Exception('Please provide excatly one of either SHs or precomputed colors!')
         | 
| 193 | 
            +
                    
         | 
| 194 | 
            +
                    if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None):
         | 
| 195 | 
            +
                        raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')
         | 
| 196 | 
            +
                    
         | 
| 197 | 
            +
                    if shs is None:
         | 
| 198 | 
            +
                        shs = torch.Tensor([])
         | 
| 199 | 
            +
                    if colors_precomp is None:
         | 
| 200 | 
            +
                        colors_precomp = torch.Tensor([])
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    if scales is None:
         | 
| 203 | 
            +
                        scales = torch.Tensor([])
         | 
| 204 | 
            +
                    if rotations is None:
         | 
| 205 | 
            +
                        rotations = torch.Tensor([])
         | 
| 206 | 
            +
                    if cov3D_precomp is None:
         | 
| 207 | 
            +
                        cov3D_precomp = torch.Tensor([])
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                    # Invoke C++/CUDA rasterization routine
         | 
| 210 | 
            +
                    return rasterize_gaussians(
         | 
| 211 | 
            +
                        means3D,
         | 
| 212 | 
            +
                        means2D,
         | 
| 213 | 
            +
                        shs,
         | 
| 214 | 
            +
                        colors_precomp,
         | 
| 215 | 
            +
                        opacities,
         | 
| 216 | 
            +
                        scales, 
         | 
| 217 | 
            +
                        rotations,
         | 
| 218 | 
            +
                        cov3D_precomp,
         | 
| 219 | 
            +
                        raster_settings, 
         | 
| 220 | 
            +
                    )
         | 
| 221 | 
            +
             | 
    	
        submodules/diff-gaussian-rasterization/ext.cpp
    ADDED
    
    | @@ -0,0 +1,19 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #include <torch/extension.h>
         | 
| 13 | 
            +
            #include "rasterize_points.h"
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         | 
| 16 | 
            +
              m.def("rasterize_gaussians", &RasterizeGaussiansCUDA);
         | 
| 17 | 
            +
              m.def("rasterize_gaussians_backward", &RasterizeGaussiansBackwardCUDA);
         | 
| 18 | 
            +
              m.def("mark_visible", &markVisible);
         | 
| 19 | 
            +
            }
         | 
    	
        submodules/diff-gaussian-rasterization/rasterize_points.cu
    ADDED
    
    | @@ -0,0 +1,217 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #include <math.h>
         | 
| 13 | 
            +
            #include <torch/extension.h>
         | 
| 14 | 
            +
            #include <cstdio>
         | 
| 15 | 
            +
            #include <sstream>
         | 
| 16 | 
            +
            #include <iostream>
         | 
| 17 | 
            +
            #include <tuple>
         | 
| 18 | 
            +
            #include <stdio.h>
         | 
| 19 | 
            +
            #include <cuda_runtime_api.h>
         | 
| 20 | 
            +
            #include <memory>
         | 
| 21 | 
            +
            #include "cuda_rasterizer/config.h"
         | 
| 22 | 
            +
            #include "cuda_rasterizer/rasterizer.h"
         | 
| 23 | 
            +
            #include <fstream>
         | 
| 24 | 
            +
            #include <string>
         | 
| 25 | 
            +
            #include <functional>
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            std::function<char*(size_t N)> resizeFunctional(torch::Tensor& t) {
         | 
| 28 | 
            +
                auto lambda = [&t](size_t N) {
         | 
| 29 | 
            +
                    t.resize_({(long long)N});
         | 
| 30 | 
            +
            		return reinterpret_cast<char*>(t.contiguous().data_ptr());
         | 
| 31 | 
            +
                };
         | 
| 32 | 
            +
                return lambda;
         | 
| 33 | 
            +
            }
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
         | 
| 36 | 
            +
            RasterizeGaussiansCUDA(
         | 
| 37 | 
            +
            	const torch::Tensor& background,
         | 
| 38 | 
            +
            	const torch::Tensor& means3D,
         | 
| 39 | 
            +
                const torch::Tensor& colors,
         | 
| 40 | 
            +
                const torch::Tensor& opacity,
         | 
| 41 | 
            +
            	const torch::Tensor& scales,
         | 
| 42 | 
            +
            	const torch::Tensor& rotations,
         | 
| 43 | 
            +
            	const float scale_modifier,
         | 
| 44 | 
            +
            	const torch::Tensor& cov3D_precomp,
         | 
| 45 | 
            +
            	const torch::Tensor& viewmatrix,
         | 
| 46 | 
            +
            	const torch::Tensor& projmatrix,
         | 
| 47 | 
            +
            	const float tan_fovx, 
         | 
| 48 | 
            +
            	const float tan_fovy,
         | 
| 49 | 
            +
                const int image_height,
         | 
| 50 | 
            +
                const int image_width,
         | 
| 51 | 
            +
            	const torch::Tensor& sh,
         | 
| 52 | 
            +
            	const int degree,
         | 
| 53 | 
            +
            	const torch::Tensor& campos,
         | 
| 54 | 
            +
            	const bool prefiltered,
         | 
| 55 | 
            +
            	const bool debug)
         | 
| 56 | 
            +
            {
         | 
| 57 | 
            +
              if (means3D.ndimension() != 2 || means3D.size(1) != 3) {
         | 
| 58 | 
            +
                AT_ERROR("means3D must have dimensions (num_points, 3)");
         | 
| 59 | 
            +
              }
         | 
| 60 | 
            +
              
         | 
| 61 | 
            +
              const int P = means3D.size(0);
         | 
| 62 | 
            +
              const int H = image_height;
         | 
| 63 | 
            +
              const int W = image_width;
         | 
| 64 | 
            +
             | 
| 65 | 
            +
              auto int_opts = means3D.options().dtype(torch::kInt32);
         | 
| 66 | 
            +
              auto float_opts = means3D.options().dtype(torch::kFloat32);
         | 
| 67 | 
            +
             | 
| 68 | 
            +
              torch::Tensor out_color = torch::full({NUM_CHANNELS, H, W}, 0.0, float_opts);
         | 
| 69 | 
            +
              torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32));
         | 
| 70 | 
            +
              
         | 
| 71 | 
            +
              torch::Device device(torch::kCUDA);
         | 
| 72 | 
            +
              torch::TensorOptions options(torch::kByte);
         | 
| 73 | 
            +
              torch::Tensor geomBuffer = torch::empty({0}, options.device(device));
         | 
| 74 | 
            +
              torch::Tensor binningBuffer = torch::empty({0}, options.device(device));
         | 
| 75 | 
            +
              torch::Tensor imgBuffer = torch::empty({0}, options.device(device));
         | 
| 76 | 
            +
              std::function<char*(size_t)> geomFunc = resizeFunctional(geomBuffer);
         | 
| 77 | 
            +
              std::function<char*(size_t)> binningFunc = resizeFunctional(binningBuffer);
         | 
| 78 | 
            +
              std::function<char*(size_t)> imgFunc = resizeFunctional(imgBuffer);
         | 
| 79 | 
            +
              
         | 
| 80 | 
            +
              int rendered = 0;
         | 
| 81 | 
            +
              if(P != 0)
         | 
| 82 | 
            +
              {
         | 
| 83 | 
            +
            	  int M = 0;
         | 
| 84 | 
            +
            	  if(sh.size(0) != 0)
         | 
| 85 | 
            +
            	  {
         | 
| 86 | 
            +
            		M = sh.size(1);
         | 
| 87 | 
            +
                  }
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            	  rendered = CudaRasterizer::Rasterizer::forward(
         | 
| 90 | 
            +
            	    geomFunc,
         | 
| 91 | 
            +
            		binningFunc,
         | 
| 92 | 
            +
            		imgFunc,
         | 
| 93 | 
            +
            	    P, degree, M,
         | 
| 94 | 
            +
            		background.contiguous().data<float>(),
         | 
| 95 | 
            +
            		W, H,
         | 
| 96 | 
            +
            		means3D.contiguous().data<float>(),
         | 
| 97 | 
            +
            		sh.contiguous().data_ptr<float>(),
         | 
| 98 | 
            +
            		colors.contiguous().data<float>(), 
         | 
| 99 | 
            +
            		opacity.contiguous().data<float>(), 
         | 
| 100 | 
            +
            		scales.contiguous().data_ptr<float>(),
         | 
| 101 | 
            +
            		scale_modifier,
         | 
| 102 | 
            +
            		rotations.contiguous().data_ptr<float>(),
         | 
| 103 | 
            +
            		cov3D_precomp.contiguous().data<float>(), 
         | 
| 104 | 
            +
            		viewmatrix.contiguous().data<float>(), 
         | 
| 105 | 
            +
            		projmatrix.contiguous().data<float>(),
         | 
| 106 | 
            +
            		campos.contiguous().data<float>(),
         | 
| 107 | 
            +
            		tan_fovx,
         | 
| 108 | 
            +
            		tan_fovy,
         | 
| 109 | 
            +
            		prefiltered,
         | 
| 110 | 
            +
            		out_color.contiguous().data<float>(),
         | 
| 111 | 
            +
            		radii.contiguous().data<int>(),
         | 
| 112 | 
            +
            		debug);
         | 
| 113 | 
            +
              }
         | 
| 114 | 
            +
              return std::make_tuple(rendered, out_color, radii, geomBuffer, binningBuffer, imgBuffer);
         | 
| 115 | 
            +
            }
         | 
| 116 | 
            +
             | 
| 117 | 
            +
            std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
         | 
| 118 | 
            +
             RasterizeGaussiansBackwardCUDA(
         | 
| 119 | 
            +
             	const torch::Tensor& background,
         | 
| 120 | 
            +
            	const torch::Tensor& means3D,
         | 
| 121 | 
            +
            	const torch::Tensor& radii,
         | 
| 122 | 
            +
                const torch::Tensor& colors,
         | 
| 123 | 
            +
            	const torch::Tensor& scales,
         | 
| 124 | 
            +
            	const torch::Tensor& rotations,
         | 
| 125 | 
            +
            	const float scale_modifier,
         | 
| 126 | 
            +
            	const torch::Tensor& cov3D_precomp,
         | 
| 127 | 
            +
            	const torch::Tensor& viewmatrix,
         | 
| 128 | 
            +
                const torch::Tensor& projmatrix,
         | 
| 129 | 
            +
            	const float tan_fovx,
         | 
| 130 | 
            +
            	const float tan_fovy,
         | 
| 131 | 
            +
                const torch::Tensor& dL_dout_color,
         | 
| 132 | 
            +
            	const torch::Tensor& sh,
         | 
| 133 | 
            +
            	const int degree,
         | 
| 134 | 
            +
            	const torch::Tensor& campos,
         | 
| 135 | 
            +
            	const torch::Tensor& geomBuffer,
         | 
| 136 | 
            +
            	const int R,
         | 
| 137 | 
            +
            	const torch::Tensor& binningBuffer,
         | 
| 138 | 
            +
            	const torch::Tensor& imageBuffer,
         | 
| 139 | 
            +
            	const bool debug) 
         | 
| 140 | 
            +
            {
         | 
| 141 | 
            +
              const int P = means3D.size(0);
         | 
| 142 | 
            +
              const int H = dL_dout_color.size(1);
         | 
| 143 | 
            +
              const int W = dL_dout_color.size(2);
         | 
| 144 | 
            +
              
         | 
| 145 | 
            +
              int M = 0;
         | 
| 146 | 
            +
              if(sh.size(0) != 0)
         | 
| 147 | 
            +
              {	
         | 
| 148 | 
            +
            	M = sh.size(1);
         | 
| 149 | 
            +
              }
         | 
| 150 | 
            +
             | 
| 151 | 
            +
              torch::Tensor dL_dmeans3D = torch::zeros({P, 3}, means3D.options());
         | 
| 152 | 
            +
              torch::Tensor dL_dmeans2D = torch::zeros({P, 3}, means3D.options());
         | 
| 153 | 
            +
              torch::Tensor dL_dcolors = torch::zeros({P, NUM_CHANNELS}, means3D.options());
         | 
| 154 | 
            +
              torch::Tensor dL_dconic = torch::zeros({P, 2, 2}, means3D.options());
         | 
| 155 | 
            +
              torch::Tensor dL_dopacity = torch::zeros({P, 1}, means3D.options());
         | 
| 156 | 
            +
              torch::Tensor dL_dcov3D = torch::zeros({P, 6}, means3D.options());
         | 
| 157 | 
            +
              torch::Tensor dL_dsh = torch::zeros({P, M, 3}, means3D.options());
         | 
| 158 | 
            +
              torch::Tensor dL_dscales = torch::zeros({P, 3}, means3D.options());
         | 
| 159 | 
            +
              torch::Tensor dL_drotations = torch::zeros({P, 4}, means3D.options());
         | 
| 160 | 
            +
              
         | 
| 161 | 
            +
              if(P != 0)
         | 
| 162 | 
            +
              {  
         | 
| 163 | 
            +
            	  CudaRasterizer::Rasterizer::backward(P, degree, M, R,
         | 
| 164 | 
            +
            	  background.contiguous().data<float>(),
         | 
| 165 | 
            +
            	  W, H, 
         | 
| 166 | 
            +
            	  means3D.contiguous().data<float>(),
         | 
| 167 | 
            +
            	  sh.contiguous().data<float>(),
         | 
| 168 | 
            +
            	  colors.contiguous().data<float>(),
         | 
| 169 | 
            +
            	  scales.data_ptr<float>(),
         | 
| 170 | 
            +
            	  scale_modifier,
         | 
| 171 | 
            +
            	  rotations.data_ptr<float>(),
         | 
| 172 | 
            +
            	  cov3D_precomp.contiguous().data<float>(),
         | 
| 173 | 
            +
            	  viewmatrix.contiguous().data<float>(),
         | 
| 174 | 
            +
            	  projmatrix.contiguous().data<float>(),
         | 
| 175 | 
            +
            	  campos.contiguous().data<float>(),
         | 
| 176 | 
            +
            	  tan_fovx,
         | 
| 177 | 
            +
            	  tan_fovy,
         | 
| 178 | 
            +
            	  radii.contiguous().data<int>(),
         | 
| 179 | 
            +
            	  reinterpret_cast<char*>(geomBuffer.contiguous().data_ptr()),
         | 
| 180 | 
            +
            	  reinterpret_cast<char*>(binningBuffer.contiguous().data_ptr()),
         | 
| 181 | 
            +
            	  reinterpret_cast<char*>(imageBuffer.contiguous().data_ptr()),
         | 
| 182 | 
            +
            	  dL_dout_color.contiguous().data<float>(),
         | 
| 183 | 
            +
            	  dL_dmeans2D.contiguous().data<float>(),
         | 
| 184 | 
            +
            	  dL_dconic.contiguous().data<float>(),  
         | 
| 185 | 
            +
            	  dL_dopacity.contiguous().data<float>(),
         | 
| 186 | 
            +
            	  dL_dcolors.contiguous().data<float>(),
         | 
| 187 | 
            +
            	  dL_dmeans3D.contiguous().data<float>(),
         | 
| 188 | 
            +
            	  dL_dcov3D.contiguous().data<float>(),
         | 
| 189 | 
            +
            	  dL_dsh.contiguous().data<float>(),
         | 
| 190 | 
            +
            	  dL_dscales.contiguous().data<float>(),
         | 
| 191 | 
            +
            	  dL_drotations.contiguous().data<float>(),
         | 
| 192 | 
            +
            	  debug);
         | 
| 193 | 
            +
              }
         | 
| 194 | 
            +
             | 
| 195 | 
            +
              return std::make_tuple(dL_dmeans2D, dL_dcolors, dL_dopacity, dL_dmeans3D, dL_dcov3D, dL_dsh, dL_dscales, dL_drotations);
         | 
| 196 | 
            +
            }
         | 
| 197 | 
            +
             | 
| 198 | 
            +
            torch::Tensor markVisible(
         | 
| 199 | 
            +
            		torch::Tensor& means3D,
         | 
| 200 | 
            +
            		torch::Tensor& viewmatrix,
         | 
| 201 | 
            +
            		torch::Tensor& projmatrix)
         | 
| 202 | 
            +
            { 
         | 
| 203 | 
            +
              const int P = means3D.size(0);
         | 
| 204 | 
            +
              
         | 
| 205 | 
            +
              torch::Tensor present = torch::full({P}, false, means3D.options().dtype(at::kBool));
         | 
| 206 | 
            +
             
         | 
| 207 | 
            +
              if(P != 0)
         | 
| 208 | 
            +
              {
         | 
| 209 | 
            +
            	CudaRasterizer::Rasterizer::markVisible(P,
         | 
| 210 | 
            +
            		means3D.contiguous().data<float>(),
         | 
| 211 | 
            +
            		viewmatrix.contiguous().data<float>(),
         | 
| 212 | 
            +
            		projmatrix.contiguous().data<float>(),
         | 
| 213 | 
            +
            		present.contiguous().data<bool>());
         | 
| 214 | 
            +
              }
         | 
| 215 | 
            +
              
         | 
| 216 | 
            +
              return present;
         | 
| 217 | 
            +
            }
         | 
    	
        submodules/diff-gaussian-rasterization/rasterize_points.h
    ADDED
    
    | @@ -0,0 +1,67 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            /*
         | 
| 2 | 
            +
             * Copyright (C) 2023, Inria
         | 
| 3 | 
            +
             * GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
             * All rights reserved.
         | 
| 5 | 
            +
             *
         | 
| 6 | 
            +
             * This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
             * under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
             *
         | 
| 9 | 
            +
             * For inquiries contact  [email protected]
         | 
| 10 | 
            +
             */
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            #pragma once
         | 
| 13 | 
            +
            #include <torch/extension.h>
         | 
| 14 | 
            +
            #include <cstdio>
         | 
| 15 | 
            +
            #include <tuple>
         | 
| 16 | 
            +
            #include <string>
         | 
| 17 | 
            +
            	
         | 
| 18 | 
            +
            std::tuple<int, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
         | 
| 19 | 
            +
            RasterizeGaussiansCUDA(
         | 
| 20 | 
            +
            	const torch::Tensor& background,
         | 
| 21 | 
            +
            	const torch::Tensor& means3D,
         | 
| 22 | 
            +
                const torch::Tensor& colors,
         | 
| 23 | 
            +
                const torch::Tensor& opacity,
         | 
| 24 | 
            +
            	const torch::Tensor& scales,
         | 
| 25 | 
            +
            	const torch::Tensor& rotations,
         | 
| 26 | 
            +
            	const float scale_modifier,
         | 
| 27 | 
            +
            	const torch::Tensor& cov3D_precomp,
         | 
| 28 | 
            +
            	const torch::Tensor& viewmatrix,
         | 
| 29 | 
            +
            	const torch::Tensor& projmatrix,
         | 
| 30 | 
            +
            	const float tan_fovx, 
         | 
| 31 | 
            +
            	const float tan_fovy,
         | 
| 32 | 
            +
                const int image_height,
         | 
| 33 | 
            +
                const int image_width,
         | 
| 34 | 
            +
            	const torch::Tensor& sh,
         | 
| 35 | 
            +
            	const int degree,
         | 
| 36 | 
            +
            	const torch::Tensor& campos,
         | 
| 37 | 
            +
            	const bool prefiltered,
         | 
| 38 | 
            +
            	const bool debug);
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
         | 
| 41 | 
            +
             RasterizeGaussiansBackwardCUDA(
         | 
| 42 | 
            +
             	const torch::Tensor& background,
         | 
| 43 | 
            +
            	const torch::Tensor& means3D,
         | 
| 44 | 
            +
            	const torch::Tensor& radii,
         | 
| 45 | 
            +
                const torch::Tensor& colors,
         | 
| 46 | 
            +
            	const torch::Tensor& scales,
         | 
| 47 | 
            +
            	const torch::Tensor& rotations,
         | 
| 48 | 
            +
            	const float scale_modifier,
         | 
| 49 | 
            +
            	const torch::Tensor& cov3D_precomp,
         | 
| 50 | 
            +
            	const torch::Tensor& viewmatrix,
         | 
| 51 | 
            +
                const torch::Tensor& projmatrix,
         | 
| 52 | 
            +
            	const float tan_fovx, 
         | 
| 53 | 
            +
            	const float tan_fovy,
         | 
| 54 | 
            +
                const torch::Tensor& dL_dout_color,
         | 
| 55 | 
            +
            	const torch::Tensor& sh,
         | 
| 56 | 
            +
            	const int degree,
         | 
| 57 | 
            +
            	const torch::Tensor& campos,
         | 
| 58 | 
            +
            	const torch::Tensor& geomBuffer,
         | 
| 59 | 
            +
            	const int R,
         | 
| 60 | 
            +
            	const torch::Tensor& binningBuffer,
         | 
| 61 | 
            +
            	const torch::Tensor& imageBuffer,
         | 
| 62 | 
            +
            	const bool debug);
         | 
| 63 | 
            +
            		
         | 
| 64 | 
            +
            torch::Tensor markVisible(
         | 
| 65 | 
            +
            		torch::Tensor& means3D,
         | 
| 66 | 
            +
            		torch::Tensor& viewmatrix,
         | 
| 67 | 
            +
            		torch::Tensor& projmatrix);
         | 
    	
        submodules/diff-gaussian-rasterization/setup.py
    ADDED
    
    | @@ -0,0 +1,34 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #
         | 
| 2 | 
            +
            # Copyright (C) 2023, Inria
         | 
| 3 | 
            +
            # GRAPHDECO research group, https://team.inria.fr/graphdeco
         | 
| 4 | 
            +
            # All rights reserved.
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # This software is free for non-commercial, research and evaluation use 
         | 
| 7 | 
            +
            # under the terms of the LICENSE.md file.
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # For inquiries contact  [email protected]
         | 
| 10 | 
            +
            #
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from setuptools import setup
         | 
| 13 | 
            +
            from torch.utils.cpp_extension import CUDAExtension, BuildExtension
         | 
| 14 | 
            +
            import os
         | 
| 15 | 
            +
            os.path.dirname(os.path.abspath(__file__))
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            setup(
         | 
| 18 | 
            +
                name="diff_gaussian_rasterization",
         | 
| 19 | 
            +
                packages=['diff_gaussian_rasterization'],
         | 
| 20 | 
            +
                ext_modules=[
         | 
| 21 | 
            +
                    CUDAExtension(
         | 
| 22 | 
            +
                        name="diff_gaussian_rasterization._C",
         | 
| 23 | 
            +
                        sources=[
         | 
| 24 | 
            +
                        "cuda_rasterizer/rasterizer_impl.cu",
         | 
| 25 | 
            +
                        "cuda_rasterizer/forward.cu",
         | 
| 26 | 
            +
                        "cuda_rasterizer/backward.cu",
         | 
| 27 | 
            +
                        "rasterize_points.cu",
         | 
| 28 | 
            +
                        "ext.cpp"],
         | 
| 29 | 
            +
                        extra_compile_args={"nvcc": ["-I" + os.path.join(os.path.dirname(os.path.abspath(__file__)), "third_party/glm/")]})
         | 
| 30 | 
            +
                    ],
         | 
| 31 | 
            +
                cmdclass={
         | 
| 32 | 
            +
                    'build_ext': BuildExtension
         | 
| 33 | 
            +
                }
         | 
| 34 | 
            +
            )
         | 
    	
        submodules/diff-gaussian-rasterization/third_party/glm/.appveyor.yml
    ADDED
    
    | @@ -0,0 +1,92 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            shallow_clone: true
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            platform:
         | 
| 4 | 
            +
              - x86
         | 
| 5 | 
            +
              - x64
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            configuration:
         | 
| 8 | 
            +
              - Debug
         | 
| 9 | 
            +
              - Release
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            image:
         | 
| 12 | 
            +
              - Visual Studio 2013
         | 
| 13 | 
            +
              - Visual Studio 2015
         | 
| 14 | 
            +
              - Visual Studio 2017
         | 
| 15 | 
            +
              - Visual Studio 2019
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            environment:
         | 
| 18 | 
            +
              matrix:
         | 
| 19 | 
            +
                - GLM_ARGUMENTS: -DGLM_TEST_FORCE_PURE=ON
         | 
| 20 | 
            +
                - GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_SSE2=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
         | 
| 21 | 
            +
                - GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
         | 
| 22 | 
            +
                - GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_14=ON
         | 
| 23 | 
            +
                - GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_17=ON
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            matrix:
         | 
| 26 | 
            +
                exclude:
         | 
| 27 | 
            +
                - image: Visual Studio 2013
         | 
| 28 | 
            +
                  GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
         | 
| 29 | 
            +
                - image: Visual Studio 2013
         | 
| 30 | 
            +
                  GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_14=ON
         | 
| 31 | 
            +
                - image: Visual Studio 2013
         | 
| 32 | 
            +
                  GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_17=ON
         | 
| 33 | 
            +
                - image: Visual Studio 2013
         | 
| 34 | 
            +
                  configuration: Debug
         | 
| 35 | 
            +
                - image: Visual Studio 2015
         | 
| 36 | 
            +
                  GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_SSE2=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON
         | 
| 37 | 
            +
                - image: Visual Studio 2015
         | 
| 38 | 
            +
                  GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_14=ON
         | 
| 39 | 
            +
                - image: Visual Studio 2015
         | 
| 40 | 
            +
                  GLM_ARGUMENTS: -DGLM_TEST_ENABLE_SIMD_AVX=ON -DGLM_TEST_ENABLE_LANG_EXTENSIONS=ON -DGLM_TEST_ENABLE_CXX_17=ON
         | 
| 41 | 
            +
                - image: Visual Studio 2015
         | 
| 42 | 
            +
                  platform: x86
         | 
| 43 | 
            +
                - image: Visual Studio 2015
         | 
| 44 | 
            +
                  configuration: Debug
         | 
| 45 | 
            +
                - image: Visual Studio 2017
         | 
| 46 | 
            +
                  platform: x86
         | 
| 47 | 
            +
                - image: Visual Studio 2017
         | 
| 48 | 
            +
                  configuration: Debug
         | 
| 49 | 
            +
                - image: Visual Studio 2019
         | 
| 50 | 
            +
                  platform: x64
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            branches:
         | 
| 53 | 
            +
              only:
         | 
| 54 | 
            +
                - master
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            before_build:
         | 
| 57 | 
            +
              - ps: |
         | 
| 58 | 
            +
                  mkdir build
         | 
| 59 | 
            +
                  cd build
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                  if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2013") {
         | 
| 62 | 
            +
                      $env:generator="Visual Studio 12 2013"
         | 
| 63 | 
            +
                  } 
         | 
| 64 | 
            +
                  if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2015") {
         | 
| 65 | 
            +
                      $env:generator="Visual Studio 14 2015"
         | 
| 66 | 
            +
                  } 
         | 
| 67 | 
            +
                  if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2017") {
         | 
| 68 | 
            +
                      $env:generator="Visual Studio 15 2017"
         | 
| 69 | 
            +
                  }
         | 
| 70 | 
            +
                  if ("$env:APPVEYOR_JOB_NAME" -match "Image: Visual Studio 2019") {
         | 
| 71 | 
            +
                      $env:generator="Visual Studio 16 2019"
         | 
| 72 | 
            +
                  }
         | 
| 73 | 
            +
                  if ($env:PLATFORM -eq "x64") {
         | 
| 74 | 
            +
                      $env:generator="$env:generator Win64"
         | 
| 75 | 
            +
                  }
         | 
| 76 | 
            +
                  echo generator="$env:generator"
         | 
| 77 | 
            +
                  cmake .. -G "$env:generator" -DCMAKE_INSTALL_PREFIX="$env:APPVEYOR_BUILD_FOLDER/install" -DGLM_QUIET=ON -DGLM_TEST_ENABLE=ON "$env:GLM_ARGUMENTS"
         | 
| 78 | 
            +
             | 
| 79 | 
            +
            build_script:
         | 
| 80 | 
            +
              - cmake --build . --parallel --config %CONFIGURATION% -- /m /v:minimal
         | 
| 81 | 
            +
              - cmake --build . --target install --parallel --config %CONFIGURATION% -- /m /v:minimal
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            test_script:
         | 
| 84 | 
            +
              - ctest --parallel 4 --verbose -C %CONFIGURATION%
         | 
| 85 | 
            +
              - cd ..
         | 
| 86 | 
            +
              - ps: |
         | 
| 87 | 
            +
                  mkdir build_test_cmake
         | 
| 88 | 
            +
                  cd build_test_cmake
         | 
| 89 | 
            +
                  cmake ..\test\cmake\ -G "$env:generator" -DCMAKE_PREFIX_PATH="$env:APPVEYOR_BUILD_FOLDER/install"
         | 
| 90 | 
            +
              - cmake --build . --parallel --config %CONFIGURATION% -- /m /v:minimal
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            deploy: off
         | 
    	
        submodules/diff-gaussian-rasterization/third_party/glm/.gitignore
    ADDED
    
    | @@ -0,0 +1,61 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Compiled Object files
         | 
| 2 | 
            +
            *.slo
         | 
| 3 | 
            +
            *.lo
         | 
| 4 | 
            +
            *.o
         | 
| 5 | 
            +
            *.obj
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            # Precompiled Headers
         | 
| 8 | 
            +
            *.gch
         | 
| 9 | 
            +
            *.pch
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # Compiled Dynamic libraries
         | 
| 12 | 
            +
            *.so
         | 
| 13 | 
            +
            *.dylib
         | 
| 14 | 
            +
            *.dll
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # Fortran module files
         | 
| 17 | 
            +
            *.mod
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            # Compiled Static libraries
         | 
| 20 | 
            +
            *.lai
         | 
| 21 | 
            +
            *.la
         | 
| 22 | 
            +
            *.a
         | 
| 23 | 
            +
            *.lib
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            # Executables
         | 
| 26 | 
            +
            *.exe
         | 
| 27 | 
            +
            *.out
         | 
| 28 | 
            +
            *.app
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            # CMake
         | 
| 31 | 
            +
            CMakeCache.txt
         | 
| 32 | 
            +
            CMakeFiles
         | 
| 33 | 
            +
            cmake_install.cmake
         | 
| 34 | 
            +
            install_manifest.txt
         | 
| 35 | 
            +
            *.cmake
         | 
| 36 | 
            +
            !glmConfig.cmake
         | 
| 37 | 
            +
            !glmConfig-version.cmake
         | 
| 38 | 
            +
            # ^ May need to add future .cmake files as exceptions
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            # Test logs
         | 
| 41 | 
            +
            Testing/*
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            # Test input
         | 
| 44 | 
            +
            test/gtc/*.dds
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            # Project Files
         | 
| 47 | 
            +
            Makefile
         | 
| 48 | 
            +
            *.cbp
         | 
| 49 | 
            +
            *.user
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            # Misc.
         | 
| 52 | 
            +
            *.log
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            # local build(s)
         | 
| 55 | 
            +
            build*
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            /.vs
         | 
| 58 | 
            +
            /.vscode
         | 
| 59 | 
            +
            /CMakeSettings.json
         | 
| 60 | 
            +
            .DS_Store
         | 
| 61 | 
            +
            *.swp
         | 
