Spaces:
Running
on
Zero
Running
on
Zero
Add training back
Browse files- demo/gs_train.py +65 -185
demo/gs_train.py
CHANGED
@@ -132,119 +132,8 @@ def train(
|
|
132 |
densify_grad_threshold=densify_grad_threshold,
|
133 |
random_background=random_background
|
134 |
)
|
135 |
-
try:
|
136 |
-
import subprocess
|
137 |
-
nvcc_version = subprocess.check_output(['nvcc', '--version']).decode('utf-8')
|
138 |
-
print("NVCC Driver Version:", nvcc_version)
|
139 |
-
except Exception as e:
|
140 |
-
print("Error fetching NVCC Driver Version:", e)
|
141 |
-
|
142 |
-
#
|
143 |
-
# Copyright (C) 2023, Inria
|
144 |
-
# GRAPHDECO research group, https://team.inria.fr/graphdeco
|
145 |
-
# All rights reserved.
|
146 |
-
#
|
147 |
-
# This software is free for non-commercial, research and evaluation use
|
148 |
-
# under the terms of the LICENSE.md file.
|
149 |
-
#
|
150 |
-
# For inquiries contact george.drettakis@inria.fr
|
151 |
-
#
|
152 |
-
print("local_renderer")
|
153 |
-
import torch
|
154 |
-
import math
|
155 |
-
from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
|
156 |
-
from scene.gaussian_model import GaussianModel
|
157 |
-
from utils.sh_utils import eval_sh
|
158 |
-
|
159 |
-
def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
|
160 |
-
"""
|
161 |
-
Render the scene.
|
162 |
-
|
163 |
-
Background tensor (bg_color) must be on GPU!
|
164 |
-
"""
|
165 |
-
|
166 |
-
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
|
167 |
-
screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
|
168 |
-
try:
|
169 |
-
screenspace_points.retain_grad()
|
170 |
-
except:
|
171 |
-
pass
|
172 |
-
|
173 |
-
# Set up rasterization configuration
|
174 |
-
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
|
175 |
-
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
|
176 |
-
|
177 |
-
kernel_size = 0.1
|
178 |
-
subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")
|
179 |
-
|
180 |
-
raster_settings = GaussianRasterizationSettings(
|
181 |
-
image_height=int(viewpoint_camera.image_height),
|
182 |
-
image_width=int(viewpoint_camera.image_width),
|
183 |
-
tanfovx=tanfovx,
|
184 |
-
tanfovy=tanfovy,
|
185 |
-
# kernel_size=kernel_size,
|
186 |
-
# subpixel_offset=subpixel_offset,
|
187 |
-
bg=bg_color,
|
188 |
-
scale_modifier=scaling_modifier,
|
189 |
-
viewmatrix=viewpoint_camera.world_view_transform,
|
190 |
-
projmatrix=viewpoint_camera.full_proj_transform,
|
191 |
-
sh_degree=pc.active_sh_degree,
|
192 |
-
campos=viewpoint_camera.camera_center,
|
193 |
-
prefiltered=False,
|
194 |
-
debug=pipe.debug
|
195 |
-
)
|
196 |
-
|
197 |
-
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
|
198 |
-
|
199 |
-
means3D = pc.get_xyz
|
200 |
-
means2D = screenspace_points
|
201 |
-
opacity = pc.get_opacity
|
202 |
-
|
203 |
-
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
|
204 |
-
# scaling / rotation by the rasterizer.
|
205 |
-
scales = None
|
206 |
-
rotations = None
|
207 |
-
cov3D_precomp = None
|
208 |
-
if pipe.compute_cov3D_python:
|
209 |
-
cov3D_precomp = pc.get_covariance(scaling_modifier)
|
210 |
-
else:
|
211 |
-
scales = pc.get_scaling
|
212 |
-
rotations = pc.get_rotation
|
213 |
-
|
214 |
-
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
|
215 |
-
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
|
216 |
-
shs = None
|
217 |
-
colors_precomp = None
|
218 |
-
if override_color is None:
|
219 |
-
if pipe.convert_SHs_python:
|
220 |
-
shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
|
221 |
-
dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
|
222 |
-
dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
|
223 |
-
sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
|
224 |
-
colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
|
225 |
-
else:
|
226 |
-
shs = pc.get_features
|
227 |
-
else:
|
228 |
-
colors_precomp = override_color
|
229 |
-
|
230 |
-
# Rasterize visible Gaussians to image, obtain their radii (on screen).
|
231 |
-
rendered_image, radii = rasterizer(
|
232 |
-
means3D = means3D,
|
233 |
-
means2D = means2D,
|
234 |
-
shs = shs,
|
235 |
-
colors_precomp = colors_precomp,
|
236 |
-
opacities = opacity,
|
237 |
-
scales = scales,
|
238 |
-
rotations = rotations,
|
239 |
-
cov3D_precomp = cov3D_precomp)
|
240 |
-
|
241 |
-
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
|
242 |
-
# They will be excluded from value updates used in the splitting criteria.
|
243 |
-
return {"render": rendered_image,
|
244 |
-
"viewspace_points": screenspace_points,
|
245 |
-
"visibility_filter" : radii > 0,
|
246 |
-
"radii": radii}
|
247 |
|
|
|
248 |
|
249 |
args = TrainingArgs()
|
250 |
|
@@ -253,15 +142,6 @@ def train(
|
|
253 |
checkpoint_iterations = args.checkpoint_iterations
|
254 |
debug_from = args.debug_from
|
255 |
|
256 |
-
pcd = torch.randn((90804, 3)).float().cuda()
|
257 |
-
print("pcd: ", pcd.shape, pcd.dtype, pcd.min(), pcd.max(), pcd.device)
|
258 |
-
print("distCUDA2: ", distCUDA2(pcd.cpu()))
|
259 |
-
print("distCUDA2: ", distCUDA2(pcd.cuda()))
|
260 |
-
|
261 |
-
dist2 = torch.clamp_min(distCUDA2(pcd.cuda()), 0.0000001)
|
262 |
-
print("dist2.shape: ", dist2.shape)
|
263 |
-
|
264 |
-
|
265 |
tb_writer = prepare_output_and_logger(dataset)
|
266 |
|
267 |
gaussians = GaussianModel(dataset.sh_degree)
|
@@ -281,73 +161,73 @@ def train(
|
|
281 |
first_iter += 1
|
282 |
|
283 |
point_cloud_path = ""
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
|
352 |
|
353 |
from os import makedirs
|
@@ -359,8 +239,8 @@ def train(
|
|
359 |
"""
|
360 |
render_resize_method: crop, pad
|
361 |
"""
|
362 |
-
|
363 |
-
|
364 |
|
365 |
iteration = scene.loaded_iter
|
366 |
|
|
|
132 |
densify_grad_threshold=densify_grad_threshold,
|
133 |
random_background=random_background
|
134 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
+
print("local_renderer")
|
137 |
|
138 |
args = TrainingArgs()
|
139 |
|
|
|
142 |
checkpoint_iterations = args.checkpoint_iterations
|
143 |
debug_from = args.debug_from
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
tb_writer = prepare_output_and_logger(dataset)
|
146 |
|
147 |
gaussians = GaussianModel(dataset.sh_degree)
|
|
|
161 |
first_iter += 1
|
162 |
|
163 |
point_cloud_path = ""
|
164 |
+
progress = gr.Progress() # Initialize the progress bar
|
165 |
+
for iteration in range(first_iter, opt.iterations + 1):
|
166 |
+
iter_start.record()
|
167 |
+
gaussians.update_learning_rate(iteration)
|
168 |
|
169 |
+
# Every 1000 iterations, raise the active SH degree by one, up to the maximum degree
|
170 |
+
if iteration % 1000 == 0:
|
171 |
+
gaussians.oneupSHdegree()
|
172 |
|
173 |
+
# Pick a random Camera
|
174 |
+
if not viewpoint_stack:
|
175 |
+
viewpoint_stack = scene.getTrainCameras().copy()
|
176 |
+
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
|
177 |
|
178 |
+
# Render
|
179 |
+
if (iteration - 1) == debug_from:
|
180 |
+
pipe.debug = True
|
181 |
+
bg = torch.rand((3), device=DEVICE) if opt.random_background else background
|
182 |
|
183 |
+
render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
|
184 |
+
image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
|
185 |
+
|
186 |
+
# Loss
|
187 |
+
gt_image = viewpoint_cam.original_image.cuda()
|
188 |
+
Ll1 = l1_loss(image, gt_image)
|
189 |
+
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
|
190 |
+
loss.backward()
|
191 |
+
iter_end.record()
|
192 |
|
193 |
+
with torch.no_grad():
|
194 |
+
# Progress bar
|
195 |
+
ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
|
196 |
+
if iteration % 10 == 0:
|
197 |
+
progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
|
198 |
+
progress_bar.update(10)
|
199 |
+
progress(iteration / opt.iterations) # Update Gradio progress bar
|
200 |
+
if iteration == opt.iterations:
|
201 |
+
progress_bar.close()
|
202 |
+
|
203 |
+
# Log and save
|
204 |
+
training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
|
205 |
+
if (iteration == opt.iterations):
|
206 |
+
point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
|
207 |
+
print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
|
208 |
+
scene.save(iteration)
|
209 |
+
|
210 |
+
# Densification
|
211 |
+
if iteration < opt.densify_until_iter:
|
212 |
+
# Keep track of max radii in image-space for pruning
|
213 |
+
gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
|
214 |
+
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
|
215 |
+
|
216 |
+
if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
|
217 |
+
size_threshold = 20 if iteration > opt.opacity_reset_interval else None
|
218 |
+
gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
|
219 |
+
|
220 |
+
if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
|
221 |
+
gaussians.reset_opacity()
|
222 |
+
|
223 |
+
# Optimizer step
|
224 |
+
if iteration < opt.iterations:
|
225 |
+
gaussians.optimizer.step()
|
226 |
+
gaussians.optimizer.zero_grad(set_to_none = True)
|
227 |
+
|
228 |
+
if (iteration == opt.iterations):
|
229 |
+
print("\n[ITER {}] Saving Checkpoint".format(iteration))
|
230 |
+
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
|
231 |
|
232 |
|
233 |
from os import makedirs
|
|
|
239 |
"""
|
240 |
render_resize_method: crop, pad
|
241 |
"""
|
242 |
+
gaussians = GaussianModel(dataset.sh_degree)
|
243 |
+
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
|
244 |
|
245 |
iteration = scene.loaded_iter
|
246 |
|