ostapagon commited on
Commit
039c399
·
1 Parent(s): 375b434

Add training back

Browse files
Files changed (1) hide show
  1. demo/gs_train.py +65 -185
demo/gs_train.py CHANGED
@@ -132,119 +132,8 @@ def train(
132
  densify_grad_threshold=densify_grad_threshold,
133
  random_background=random_background
134
  )
135
- try:
136
- import subprocess
137
- nvcc_version = subprocess.check_output(['nvcc', '--version']).decode('utf-8')
138
- print("NVCC Driver Version:", nvcc_version)
139
- except Exception as e:
140
- print("Error fetching NVCC Driver Version:", e)
141
-
142
- #
143
- # Copyright (C) 2023, Inria
144
- # GRAPHDECO research group, https://team.inria.fr/graphdeco
145
- # All rights reserved.
146
- #
147
- # This software is free for non-commercial, research and evaluation use
148
- # under the terms of the LICENSE.md file.
149
- #
150
- # For inquiries contact [email protected]
151
- #
152
- print("local_renderer")
153
- import torch
154
- import math
155
- from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
156
- from scene.gaussian_model import GaussianModel
157
- from utils.sh_utils import eval_sh
158
-
159
- def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
160
- """
161
- Render the scene.
162
-
163
- Background tensor (bg_color) must be on GPU!
164
- """
165
-
166
- # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
167
- screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
168
- try:
169
- screenspace_points.retain_grad()
170
- except:
171
- pass
172
-
173
- # Set up rasterization configuration
174
- tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
175
- tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
176
-
177
- kernel_size = 0.1
178
- subpixel_offset = torch.zeros((int(viewpoint_camera.image_height), int(viewpoint_camera.image_width), 2), dtype=torch.float32, device="cuda")
179
-
180
- raster_settings = GaussianRasterizationSettings(
181
- image_height=int(viewpoint_camera.image_height),
182
- image_width=int(viewpoint_camera.image_width),
183
- tanfovx=tanfovx,
184
- tanfovy=tanfovy,
185
- # kernel_size=kernel_size,
186
- # subpixel_offset=subpixel_offset,
187
- bg=bg_color,
188
- scale_modifier=scaling_modifier,
189
- viewmatrix=viewpoint_camera.world_view_transform,
190
- projmatrix=viewpoint_camera.full_proj_transform,
191
- sh_degree=pc.active_sh_degree,
192
- campos=viewpoint_camera.camera_center,
193
- prefiltered=False,
194
- debug=pipe.debug
195
- )
196
-
197
- rasterizer = GaussianRasterizer(raster_settings=raster_settings)
198
-
199
- means3D = pc.get_xyz
200
- means2D = screenspace_points
201
- opacity = pc.get_opacity
202
-
203
- # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
204
- # scaling / rotation by the rasterizer.
205
- scales = None
206
- rotations = None
207
- cov3D_precomp = None
208
- if pipe.compute_cov3D_python:
209
- cov3D_precomp = pc.get_covariance(scaling_modifier)
210
- else:
211
- scales = pc.get_scaling
212
- rotations = pc.get_rotation
213
-
214
- # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
215
- # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
216
- shs = None
217
- colors_precomp = None
218
- if override_color is None:
219
- if pipe.convert_SHs_python:
220
- shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
221
- dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
222
- dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
223
- sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
224
- colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
225
- else:
226
- shs = pc.get_features
227
- else:
228
- colors_precomp = override_color
229
-
230
- # Rasterize visible Gaussians to image, obtain their radii (on screen).
231
- rendered_image, radii = rasterizer(
232
- means3D = means3D,
233
- means2D = means2D,
234
- shs = shs,
235
- colors_precomp = colors_precomp,
236
- opacities = opacity,
237
- scales = scales,
238
- rotations = rotations,
239
- cov3D_precomp = cov3D_precomp)
240
-
241
- # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
242
- # They will be excluded from value updates used in the splitting criteria.
243
- return {"render": rendered_image,
244
- "viewspace_points": screenspace_points,
245
- "visibility_filter" : radii > 0,
246
- "radii": radii}
247
 
 
248
 
249
  args = TrainingArgs()
250
 
@@ -253,15 +142,6 @@ def train(
253
  checkpoint_iterations = args.checkpoint_iterations
254
  debug_from = args.debug_from
255
 
256
- pcd = torch.randn((90804, 3)).float().cuda()
257
- print("pcd: ", pcd.shape, pcd.dtype, pcd.min(), pcd.max(), pcd.device)
258
- print("distCUDA2: ", distCUDA2(pcd.cpu()))
259
- print("distCUDA2: ", distCUDA2(pcd.cuda()))
260
-
261
- dist2 = torch.clamp_min(distCUDA2(pcd.cuda()), 0.0000001)
262
- print("dist2.shape: ", dist2.shape)
263
-
264
-
265
  tb_writer = prepare_output_and_logger(dataset)
266
 
267
  gaussians = GaussianModel(dataset.sh_degree)
@@ -281,73 +161,73 @@ def train(
281
  first_iter += 1
282
 
283
  point_cloud_path = ""
284
- # progress = gr.Progress() # Initialize the progress bar
285
- # for iteration in range(first_iter, opt.iterations + 1):
286
- # iter_start.record()
287
- # gaussians.update_learning_rate(iteration)
288
 
289
- # # Every 1000 its we increase the levels of SH up to a maximum degree
290
- # if iteration % 1000 == 0:
291
- # gaussians.oneupSHdegree()
292
 
293
- # # Pick a random Camera
294
- # if not viewpoint_stack:
295
- # viewpoint_stack = scene.getTrainCameras().copy()
296
- # viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
297
 
298
- # # Render
299
- # if (iteration - 1) == debug_from:
300
- # pipe.debug = True
301
- # bg = torch.rand((3), device=DEVICE) if opt.random_background else background
302
 
303
- # render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
304
- # image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
305
-
306
- # # Loss
307
- # gt_image = viewpoint_cam.original_image.cuda()
308
- # Ll1 = l1_loss(image, gt_image)
309
- # loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
310
- # loss.backward()
311
- # iter_end.record()
312
 
313
- # with torch.no_grad():
314
- # # Progress bar
315
- # ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
316
- # if iteration % 10 == 0:
317
- # progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
318
- # progress_bar.update(10)
319
- # progress(iteration / opt.iterations) # Update Gradio progress bar
320
- # if iteration == opt.iterations:
321
- # progress_bar.close()
322
-
323
- # # Log and save
324
- # training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
325
- # if (iteration == opt.iterations):
326
- # point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
327
- # print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
328
- # scene.save(iteration)
329
-
330
- # # Densification
331
- # if iteration < opt.densify_until_iter:
332
- # # Keep track of max radii in image-space for pruning
333
- # gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
334
- # gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
335
-
336
- # if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
337
- # size_threshold = 20 if iteration > opt.opacity_reset_interval else None
338
- # gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
339
-
340
- # if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
341
- # gaussians.reset_opacity()
342
-
343
- # # Optimizer step
344
- # if iteration < opt.iterations:
345
- # gaussians.optimizer.step()
346
- # gaussians.optimizer.zero_grad(set_to_none = True)
347
-
348
- # if (iteration == opt.iterations):
349
- # print("\n[ITER {}] Saving Checkpoint".format(iteration))
350
- # torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
351
 
352
 
353
  from os import makedirs
@@ -359,8 +239,8 @@ def train(
359
  """
360
  render_resize_method: crop, pad
361
  """
362
- # gaussians = GaussianModel(dataset.sh_degree)
363
- # scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
364
 
365
  iteration = scene.loaded_iter
366
 
 
132
  densify_grad_threshold=densify_grad_threshold,
133
  random_background=random_background
134
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
+ print("local_renderer")
137
 
138
  args = TrainingArgs()
139
 
 
142
  checkpoint_iterations = args.checkpoint_iterations
143
  debug_from = args.debug_from
144
 
 
 
 
 
 
 
 
 
 
145
  tb_writer = prepare_output_and_logger(dataset)
146
 
147
  gaussians = GaussianModel(dataset.sh_degree)
 
161
  first_iter += 1
162
 
163
  point_cloud_path = ""
164
+ progress = gr.Progress() # Initialize the progress bar
165
+ for iteration in range(first_iter, opt.iterations + 1):
166
+ iter_start.record()
167
+ gaussians.update_learning_rate(iteration)
168
 
169
+ # Every 1000 its we increase the levels of SH up to a maximum degree
170
+ if iteration % 1000 == 0:
171
+ gaussians.oneupSHdegree()
172
 
173
+ # Pick a random Camera
174
+ if not viewpoint_stack:
175
+ viewpoint_stack = scene.getTrainCameras().copy()
176
+ viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
177
 
178
+ # Render
179
+ if (iteration - 1) == debug_from:
180
+ pipe.debug = True
181
+ bg = torch.rand((3), device=DEVICE) if opt.random_background else background
182
 
183
+ render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
184
+ image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
185
+
186
+ # Loss
187
+ gt_image = viewpoint_cam.original_image.cuda()
188
+ Ll1 = l1_loss(image, gt_image)
189
+ loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
190
+ loss.backward()
191
+ iter_end.record()
192
 
193
+ with torch.no_grad():
194
+ # Progress bar
195
+ ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
196
+ if iteration % 10 == 0:
197
+ progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
198
+ progress_bar.update(10)
199
+ progress(iteration / opt.iterations) # Update Gradio progress bar
200
+ if iteration == opt.iterations:
201
+ progress_bar.close()
202
+
203
+ # Log and save
204
+ training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
205
+ if (iteration == opt.iterations):
206
+ point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
207
+ print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
208
+ scene.save(iteration)
209
+
210
+ # Densification
211
+ if iteration < opt.densify_until_iter:
212
+ # Keep track of max radii in image-space for pruning
213
+ gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
214
+ gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
215
+
216
+ if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
217
+ size_threshold = 20 if iteration > opt.opacity_reset_interval else None
218
+ gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
219
+
220
+ if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
221
+ gaussians.reset_opacity()
222
+
223
+ # Optimizer step
224
+ if iteration < opt.iterations:
225
+ gaussians.optimizer.step()
226
+ gaussians.optimizer.zero_grad(set_to_none = True)
227
+
228
+ if (iteration == opt.iterations):
229
+ print("\n[ITER {}] Saving Checkpoint".format(iteration))
230
+ torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
231
 
232
 
233
  from os import makedirs
 
239
  """
240
  render_resize_method: crop, pad
241
  """
242
+ gaussians = GaussianModel(dataset.sh_degree)
243
+ scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
244
 
245
  iteration = scene.loaded_iter
246