ostapient commited on
Commit
d47c0cf
·
1 Parent(s): a2ddf14

Just giving up on this idea

Browse files
Files changed (1) hide show
  1. demo/gs_train.py +62 -64
demo/gs_train.py CHANGED
@@ -12,8 +12,6 @@ from demo_globals import DEVICE
12
  import spaces
13
  from simple_knn._C import distCUDA2
14
 
15
-
16
-
17
  @dataclass
18
  class PipelineParams:
19
  convert_SHs_python: bool = False
@@ -169,73 +167,73 @@ def train(
169
  first_iter += 1
170
 
171
  point_cloud_path = ""
172
- progress = gr.Progress() # Initialize the progress bar
173
- for iteration in range(first_iter, opt.iterations + 1):
174
- iter_start.record()
175
- gaussians.update_learning_rate(iteration)
176
 
177
- # Every 1000 its we increase the levels of SH up to a maximum degree
178
- if iteration % 1000 == 0:
179
- gaussians.oneupSHdegree()
180
 
181
- # Pick a random Camera
182
- if not viewpoint_stack:
183
- viewpoint_stack = scene.getTrainCameras().copy()
184
- viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
185
 
186
- # Render
187
- if (iteration - 1) == debug_from:
188
- pipe.debug = True
189
- bg = torch.rand((3), device=DEVICE) if opt.random_background else background
190
 
191
- render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
192
- image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
193
-
194
- # Loss
195
- gt_image = viewpoint_cam.original_image.cuda()
196
- Ll1 = l1_loss(image, gt_image)
197
- loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
198
- loss.backward()
199
- iter_end.record()
200
 
201
- with torch.no_grad():
202
- # Progress bar
203
- ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
204
- if iteration % 10 == 0:
205
- progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
206
- progress_bar.update(10)
207
- progress(iteration / opt.iterations) # Update Gradio progress bar
208
- if iteration == opt.iterations:
209
- progress_bar.close()
210
-
211
- # Log and save
212
- training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
213
- if (iteration == opt.iterations):
214
- point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
215
- print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
216
- scene.save(iteration)
217
-
218
- # Densification
219
- if iteration < opt.densify_until_iter:
220
- # Keep track of max radii in image-space for pruning
221
- gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
222
- gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
223
-
224
- if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
225
- size_threshold = 20 if iteration > opt.opacity_reset_interval else None
226
- gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
227
-
228
- if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
229
- gaussians.reset_opacity()
230
-
231
- # Optimizer step
232
- if iteration < opt.iterations:
233
- gaussians.optimizer.step()
234
- gaussians.optimizer.zero_grad(set_to_none = True)
235
-
236
- if (iteration == opt.iterations):
237
- print("\n[ITER {}] Saving Checkpoint".format(iteration))
238
- torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
239
 
240
 
241
  from os import makedirs
 
12
  import spaces
13
  from simple_knn._C import distCUDA2
14
 
 
 
15
  @dataclass
16
  class PipelineParams:
17
  convert_SHs_python: bool = False
 
167
  first_iter += 1
168
 
169
  point_cloud_path = ""
170
+ # progress = gr.Progress() # Initialize the progress bar
171
+ # for iteration in range(first_iter, opt.iterations + 1):
172
+ # iter_start.record()
173
+ # gaussians.update_learning_rate(iteration)
174
 
175
+ # # Every 1000 its we increase the levels of SH up to a maximum degree
176
+ # if iteration % 1000 == 0:
177
+ # gaussians.oneupSHdegree()
178
 
179
+ # # Pick a random Camera
180
+ # if not viewpoint_stack:
181
+ # viewpoint_stack = scene.getTrainCameras().copy()
182
+ # viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
183
 
184
+ # # Render
185
+ # if (iteration - 1) == debug_from:
186
+ # pipe.debug = True
187
+ # bg = torch.rand((3), device=DEVICE) if opt.random_background else background
188
 
189
+ # render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
190
+ # image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
191
+
192
+ # # Loss
193
+ # gt_image = viewpoint_cam.original_image.cuda()
194
+ # Ll1 = l1_loss(image, gt_image)
195
+ # loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
196
+ # loss.backward()
197
+ # iter_end.record()
198
 
199
+ # with torch.no_grad():
200
+ # # Progress bar
201
+ # ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
202
+ # if iteration % 10 == 0:
203
+ # progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
204
+ # progress_bar.update(10)
205
+ # progress(iteration / opt.iterations) # Update Gradio progress bar
206
+ # if iteration == opt.iterations:
207
+ # progress_bar.close()
208
+
209
+ # # Log and save
210
+ # training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
211
+ # if (iteration == opt.iterations):
212
+ # point_cloud_path = os.path.join(os.path.join(dataset.model_path, "point_cloud/iteration_{}".format(iteration)), "point_cloud.ply")
213
+ # print("\n[ITER {}] Saving Gaussians to {}".format(iteration, point_cloud_path))
214
+ # scene.save(iteration)
215
+
216
+ # # Densification
217
+ # if iteration < opt.densify_until_iter:
218
+ # # Keep track of max radii in image-space for pruning
219
+ # gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
220
+ # gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
221
+
222
+ # if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
223
+ # size_threshold = 20 if iteration > opt.opacity_reset_interval else None
224
+ # gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
225
+
226
+ # if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
227
+ # gaussians.reset_opacity()
228
+
229
+ # # Optimizer step
230
+ # if iteration < opt.iterations:
231
+ # gaussians.optimizer.step()
232
+ # gaussians.optimizer.zero_grad(set_to_none = True)
233
+
234
+ # if (iteration == opt.iterations):
235
+ # print("\n[ITER {}] Saving Checkpoint".format(iteration))
236
+ # torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
237
 
238
 
239
  from os import makedirs