# mast3r-3dgs/demo/gs_train.py
import sys
import os
import torch
from random import randint
import uuid
from tqdm.auto import tqdm
import gradio as gr
import importlib.util
from dataclasses import dataclass, field
from demo_globals import DEVICE
import spaces
from simple_knn._C import distCUDA2
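
# The dataclasses below mirror the argument groups of the reference
# gaussian-splatting training script (ModelParams, PipelineParams,
# OptimizationParams and its command-line options), so the demo can reuse
# its utilities without argparse.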
@dataclass
class PipelineParams:
convert_SHs_python: bool = False
compute_cov3D_python: bool = False
debug: bool = False
@dataclass
class OptimizationParams:
iterations: int = 7000
position_lr_init: float = 0.00016
position_lr_final: float = 0.0000016
position_lr_delay_mult: float = 0.01
position_lr_max_steps: int = 30_000
feature_lr: float = 0.0025
opacity_lr: float = 0.05
scaling_lr: float = 0.005
rotation_lr: float = 0.001
percent_dense: float = 0.01
lambda_dssim: float = 0.2
densification_interval: int = 100
opacity_reset_interval: int = 3000
densify_from_iter: int = 500
densify_until_iter: int = 15_000
densify_grad_threshold: float = 0.0002
random_background: bool = False
@dataclass
class ModelParams:
sh_degree: int = 3
source_path: str = "../data/scenes/turtle/" # Default path, adjust as needed
model_path: str = ""
images: str = "images"
resolution: int = -1
white_background: bool = True
data_device: str = "cuda"
eval: bool = False
@dataclass
class TrainingArgs:
ip: str = "0.0.0.0"
port: int = 6007
debug_from: int = -1
detect_anomaly: bool = False
test_iterations: list[int] = field(default_factory=lambda: [7_000, 30_000])
save_iterations: list[int] = field(default_factory=lambda: [7_000, 30_000])
quiet: bool = False
checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
    start_checkpoint: str | None = None
@spaces.GPU(duration=20)
def train(
data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
convert_SHs_python, compute_cov3D_python, debug,
iterations, position_lr_init, position_lr_final, position_lr_delay_mult,
position_lr_max_steps, feature_lr, opacity_lr, scaling_lr, rotation_lr,
percent_dense, lambda_dssim, densification_interval, opacity_reset_interval,
densify_from_iter, densify_until_iter, densify_grad_threshold, random_background
):
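    """Build a 3DGS scene from the prepared data and render it to a video.

    The parameters mirror the ModelParams / PipelineParams / OptimizationParams
    dataclasses above. Returns the path of the rendered MP4 and of the saved
    point cloud (an empty string while the optimization loop below is
    disabled).
    """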
    # Make the gaussian-splatting repository importable (append the path only
    # once, even across repeated calls).
    gaussian_splatting_path = 'wild-gaussian-splatting/gaussian-splatting/'
    if gaussian_splatting_path not in sys.path:
        sys.path.append(gaussian_splatting_path)

    # Import the needed modules from the gaussian-splatting directory
    from utils.loss_utils import l1_loss, ssim
    # from gaussian_renderer import render
    from scene import Scene, GaussianModel
    from utils.general_utils import safe_state
    from utils.image_utils import psnr
    from utils.graphics_utils import focal2fov, fov2focal, getProjectionMatrix

    # Dynamically load the reference train.py and pull out the helpers we reuse
    train_spec = importlib.util.spec_from_file_location("gaussian_splatting_train", os.path.join(gaussian_splatting_path, "train.py"))
    gaussian_splatting_train = importlib.util.module_from_spec(train_spec)
    train_spec.loader.exec_module(gaussian_splatting_train)
    prepare_output_and_logger = gaussian_splatting_train.prepare_output_and_logger
    training_report = gaussian_splatting_train.training_report
    print("Scene source path:", data_source_path)
# Create instances of the parameter dataclasses
dataset = ModelParams(
sh_degree=sh_degree,
source_path=data_source_path,
model_path=model_path,
images=images,
resolution=resolution,
white_background=white_background,
data_device=data_device,
eval=eval
)
pipe = PipelineParams(
convert_SHs_python=convert_SHs_python,
compute_cov3D_python=compute_cov3D_python,
debug=debug
)
opt = OptimizationParams(
iterations=iterations,
position_lr_init=position_lr_init,
position_lr_final=position_lr_final,
position_lr_delay_mult=position_lr_delay_mult,
position_lr_max_steps=position_lr_max_steps,
feature_lr=feature_lr,
opacity_lr=opacity_lr,
scaling_lr=scaling_lr,
rotation_lr=rotation_lr,
percent_dense=percent_dense,
lambda_dssim=lambda_dssim,
densification_interval=densification_interval,
opacity_reset_interval=opacity_reset_interval,
densify_from_iter=densify_from_iter,
densify_until_iter=densify_until_iter,
densify_grad_threshold=densify_grad_threshold,
random_background=random_background
)
    # Log the CUDA toolkit (nvcc) version to help debug the build environment.
    try:
        import subprocess
        nvcc_version = subprocess.check_output(['nvcc', '--version']).decode('utf-8')
        print("NVCC version:", nvcc_version)
    except Exception as e:
        print("Error fetching NVCC version:", e)
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#
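    # Below is a local copy of the renderer from the gaussian-splatting repo
    # (gaussian_renderer/__init__.py), inlined in place of the import
    # commented out above so it can be modified for this demo.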
print("local_renderer")
    import math
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from scene.gaussian_model import GaussianModel
from utils.sh_utils import eval_sh
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
"""
Render the scene.
Background tensor (bg_color) must be on GPU!
"""
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
        try:
            screenspace_points.retain_grad()
        except Exception:
            pass
# Set up rasterization configuration
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
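        # kernel_size / subpixel_offset are only consumed by extended
        # rasterizer builds (e.g. the Mip-Splatting fork); the stock
        # diff_gaussian_rasterization settings below ignore them, hence the
        # commented-out fields.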
kernel_size = 0.1
subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")
raster_settings = GaussianRasterizationSettings(
image_height=int(viewpoint_camera.image_height),
image_width=int(viewpoint_camera.image_width),
tanfovx=tanfovx,
tanfovy=tanfovy,
# kernel_size=kernel_size,
# subpixel_offset=subpixel_offset,
bg=bg_color,
scale_modifier=scaling_modifier,
viewmatrix=viewpoint_camera.world_view_transform,
projmatrix=viewpoint_camera.full_proj_transform,
sh_degree=pc.active_sh_degree,
campos=viewpoint_camera.camera_center,
prefiltered=False,
debug=pipe.debug
)
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
means3D = pc.get_xyz
means2D = screenspace_points
opacity = pc.get_opacity
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
# scaling / rotation by the rasterizer.
scales = None
rotations = None
cov3D_precomp = None
if pipe.compute_cov3D_python:
cov3D_precomp = pc.get_covariance(scaling_modifier)
else:
scales = pc.get_scaling
rotations = pc.get_rotation
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
shs = None
colors_precomp = None
if override_color is None:
if pipe.convert_SHs_python:
shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
else:
shs = pc.get_features
else:
colors_precomp = override_color
# Rasterize visible Gaussians to image, obtain their radii (on screen).
rendered_image, radii = rasterizer(
means3D = means3D,
means2D = means2D,
shs = shs,
colors_precomp = colors_precomp,
opacities = opacity,
scales = scales,
rotations = rotations,
cov3D_precomp = cov3D_precomp)
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
return {"render": rendered_image,
"viewspace_points": screenspace_points,
"visibility_filter" : radii > 0,
"radii": radii}
args = TrainingArgs()
testing_iterations = args.test_iterations
saving_iterations = args.save_iterations
checkpoint_iterations = args.checkpoint_iterations
debug_from = args.debug_from
    # Sanity-check the simple_knn CUDA extension: distCUDA2 expects a CUDA
    # tensor and returns, per point, the mean squared distance to its nearest
    # neighbours (used to initialize Gaussian scales).
    pcd = torch.randn((90804, 3)).float().cuda()
    print("pcd: ", pcd.shape, pcd.dtype, pcd.min(), pcd.max(), pcd.device)
    dist2 = torch.clamp_min(distCUDA2(pcd), 0.0000001)
    print("dist2.shape: ", dist2.shape)
tb_writer = prepare_output_and_logger(dataset)
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians)
gaussians.training_setup(opt)
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
iter_start = torch.cuda.Event(enable_timing = True)
iter_end = torch.cuda.Event(enable_timing = True)
viewpoint_stack = None
ema_loss_for_log = 0.0
first_iter = 0
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
first_iter += 1
point_cloud_path = ""
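
    # NOTE: the standard 3DGS optimization loop (adapted from the repo's
    # train.py) is kept below but is currently commented out.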
# progress = gr.Progress() # Initialize the progress bar
# for iteration in range(first_iter, opt.iterations + 1):
# iter_start.record()
# gaussians.update_learning_rate(iteration)
# # Every 1000 its we increase the levels of SH up to a maximum degree
# if iteration % 1000 == 0:
# gaussians.oneupSHdegree()
# # Pick a random Camera
# if not viewpoint_stack:
# viewpoint_stack = scene.getTrainCameras().copy()
# viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
# # Render
# if (iteration - 1) == debug_from:
# pipe.debug = True
# bg = torch.rand((3), device=DEVICE) if opt.random_background else background
# render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
# image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
# # Loss
# gt_image = viewpoint_cam.original_image.cuda()
# Ll1 = l1_loss(image, gt_image)
# loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
# loss.backward()
# iter_end.record()
# with torch.no_grad():
# # Progress bar
# ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
# if iteration % 10 == 0:
# progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
# progress_bar.update(10)
# progress(iteration / opt.iterations) # Update Gradio progress bar
# if iteration == opt.iterations:
# progress_bar.close()
# # Log and save
# training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
# if (iteration == opt.iterations):
# point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
# print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
# scene.save(iteration)
# # Densification
# if iteration < opt.densify_until_iter:
# # Keep track of max radii in image-space for pruning
# gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
# gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
# if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
# size_threshold = 20 if iteration > opt.opacity_reset_interval else None
# gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
# if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
# gaussians.reset_opacity()
# # Optimizer step
# if iteration < opt.iterations:
# gaussians.optimizer.step()
# gaussians.optimizer.zero_grad(set_to_none = True)
# if (iteration == opt.iterations):
# print("\n[ITER {}] Saving Checkpoint".format(iteration))
# torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
from os import makedirs
import torchvision
import subprocess
@torch.no_grad()
def render_path(dataset : ModelParams, iteration : int, pipeline : PipelineParams, render_resize_method='crop'):
"""
render_resize_method: crop, pad
"""
# gaussians = GaussianModel(dataset.sh_degree)
# scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
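        # The Scene above was created without load_iteration, so loaded_iter
        # may be None and the output folder is then named "ours_None".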
iteration = scene.loaded_iter
bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
model_path = dataset.model_path
name = "render"
views = scene.getRenderCameras()
# print(len(views))
        frames_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
        makedirs(frames_path, exist_ok=True)
for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
if render_resize_method == 'crop':
image_size = 256
elif render_resize_method == 'pad':
image_size = max(view.image_width, view.image_height)
else:
raise NotImplementedError
view.original_image = torch.zeros((3, image_size, image_size), device=view.original_image.device)
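            # Keep the focal length fixed while resizing: convert FoV to focal
            # length at the original size, then back to FoV at the new square
            # size, and rebuild the projection matrices accordingly.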
focal_length_x = fov2focal(view.FoVx, view.image_width)
focal_length_y = fov2focal(view.FoVy, view.image_height)
view.image_width = image_size
view.image_height = image_size
view.FoVx = focal2fov(focal_length_x, image_size)
view.FoVy = focal2fov(focal_length_y, image_size)
view.projection_matrix = getProjectionMatrix(znear=view.znear, zfar=view.zfar, fovX=view.FoVx, fovY=view.FoVy).transpose(0,1).cuda().float()
view.full_proj_transform = (view.world_view_transform.unsqueeze(0).bmm(view.projection_matrix.unsqueeze(0))).squeeze(0)
# print("background.device: ", background.device)
# print("view.device: ", view.original_image.device)
render_pkg = render(view, gaussians, pipeline, background)
rendering = render_pkg["render"]
            torchvision.utils.save_image(rendering, os.path.join(frames_path, '{0:05d}'.format(idx) + ".png"))
        # Encode the rendered frames into an H.264 video with ffmpeg
        renders_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders.mp4")
        subprocess.run(["ffmpeg", "-y",
                        "-framerate", "24",
                        "-i", os.path.join(frames_path, "%05d.png"),
                        # pad to even dimensions, as required by libx264 + yuv420p
                        "-vf", "pad=ceil(iw/2)*2:ceil(ih/2)*2",
                        "-c:v", "libx264",
                        "-pix_fmt", "yuv420p",
                        "-crf", "23",
                        renders_path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
                       )
return renders_path
renders_path = render_path(dataset, opt.iterations, pipe, render_resize_method='crop')
torch.cuda.empty_cache()
return renders_path, point_cloud_path
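
# Minimal usage sketch (an assumption, not part of the original demo wiring):
# the Gradio app is expected to pass a scene directory produced by the MASt3R
# reconstruction step. The path below is hypothetical, and the call only works
# where the spaces.GPU decorator can grant a GPU.
if __name__ == "__main__":
    _opt = OptimizationParams()
    _model = ModelParams()
    _pipe = PipelineParams()
    video_path, ply_path = train(
        "../data/scenes/turtle/",  # hypothetical scene directory
        _model.sh_degree, _model.model_path, _model.images, _model.resolution,
        _model.white_background, _model.data_device, _model.eval,
        _pipe.convert_SHs_python, _pipe.compute_cov3D_python, _pipe.debug,
        _opt.iterations, _opt.position_lr_init, _opt.position_lr_final,
        _opt.position_lr_delay_mult, _opt.position_lr_max_steps,
        _opt.feature_lr, _opt.opacity_lr, _opt.scaling_lr, _opt.rotation_lr,
        _opt.percent_dense, _opt.lambda_dssim, _opt.densification_interval,
        _opt.opacity_reset_interval, _opt.densify_from_iter,
        _opt.densify_until_iter, _opt.densify_grad_threshold,
        _opt.random_background,
    )
    print("video:", video_path, "point cloud:", ply_path)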