ostapagon commited on
Commit
4162035
·
1 Parent(s): 0c0ba9f

Increase time, add test line to check if cuda() calls are legal

Browse files
Files changed (2) hide show
  1. demo/gs_train.py +2 -2
  2. demo/mast3r_demo.py +2 -0
demo/gs_train.py CHANGED
@@ -60,7 +60,7 @@ class TrainingArgs:
60
  checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
61
  start_checkpoint: str = None
62
 
63
- @spaces.GPU(duration=10)
64
  def train(
65
  data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
66
  convert_SHs_python, compute_cov3D_python, debug,
@@ -241,7 +241,7 @@ def train(
241
  iteration = scene.loaded_iter
242
 
243
  bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
244
- background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
245
 
246
  model_path = dataset.model_path
247
  name = "render"
 
60
  checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
61
  start_checkpoint: str = None
62
 
63
+ @spaces.GPU(duration=90)
64
  def train(
65
  data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
66
  convert_SHs_python, compute_cov3D_python, debug,
 
241
  iteration = scene.loaded_iter
242
 
243
  bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
244
+ background = torch.tensor(bg_color, dtype=torch.float32, device=DEVICE)
245
 
246
  model_path = dataset.model_path
247
  name = "render"
demo/mast3r_demo.py CHANGED
@@ -213,6 +213,8 @@ def get_reconstructed_scene(image_size, current_scene_state,
213
  run_counter += 1
214
  return run_cache_dir
215
 
 
 
216
  cache_dir = get_next_dir(base_cache_dir)
217
  scene = sparse_global_alignment(filelist, pairs, cache_dir,
218
  MODEL, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=DEVICE,
 
213
  run_counter += 1
214
  return run_cache_dir
215
 
216
+
217
+ ten = torch.zeros((1024)).cuda()
218
  cache_dir = get_next_dir(base_cache_dir)
219
  scene = sparse_global_alignment(filelist, pairs, cache_dir,
220
  MODEL, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=DEVICE,