ostapagon commited on
Commit
eb89bfc
·
1 Parent(s): 9652ffd

Add lazy caching for the examples

Browse files
Files changed (4) hide show
  1. app.py +4 -3
  2. demo/demo_globals.py +2 -0
  3. demo/gs_demo.py +13 -10
  4. demo/mast3r_demo.py +56 -35
app.py CHANGED
@@ -22,7 +22,7 @@ def end_session(req: gr.Request):
22
  if __name__ == '__main__':
23
  with gr.Blocks() as demo:
24
  gr.HTML('''
25
- <div style="text-align: center; padding: 20px; background-color: #f9f9f9; border-radius: 10px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);">
26
  <h2>MASt3R and 3DGS Pipeline Demo</h2>
27
  <p style="font-size: 16px;">This pipeline is designed for 3D reconstruction using MASt3R and 3DGS.</p>
28
  <p style="font-size: 16px;">The process is divided into two stages:</p>
@@ -30,7 +30,7 @@ if __name__ == '__main__':
30
  <li>MASt3R is used to obtain the initial point cloud and camera parameters.</li>
31
  <li>3DGS is then trained on the results from MASt3R to refine the 3D scene representation.</li>
32
  </ol>
33
- <p style="font-size: 16px;">Note: After a page reload, any generated MASt3R datasets in the 3DGS tab will be deleted.</p>
34
  <p style="font-size: 16px;">For a full version of this pipeline, please visit the repository at:</p>
35
  <a href="https://github.com/nerlfield/wild-gaussian-splatting" target="_blank" style="font-size: 16px; text-decoration: none;">nerlfield/wild-gaussian-splatting</a>
36
  </div>
@@ -45,4 +45,5 @@ if __name__ == '__main__':
45
  demo.load(start_session)
46
  demo.unload(end_session)
47
 
48
- demo.launch(show_error=True, share=None, server_name=None, server_port=None)
 
 
22
  if __name__ == '__main__':
23
  with gr.Blocks() as demo:
24
  gr.HTML('''
25
+ <div style="text-align: center; padding: 20px; color: #333;">
26
  <h2>MASt3R and 3DGS Pipeline Demo</h2>
27
  <p style="font-size: 16px;">This pipeline is designed for 3D reconstruction using MASt3R and 3DGS.</p>
28
  <p style="font-size: 16px;">The process is divided into two stages:</p>
 
30
  <li>MASt3R is used to obtain the initial point cloud and camera parameters.</li>
31
  <li>3DGS is then trained on the results from MASt3R to refine the 3D scene representation.</li>
32
  </ol>
33
+ <p style="font-size: 16px; font-weight: bold;">Note: After a page reload, any generated MASt3R datasets in the 3DGS tab will be deleted.</p>
34
  <p style="font-size: 16px;">For a full version of this pipeline, please visit the repository at:</p>
35
  <a href="https://github.com/nerlfield/wild-gaussian-splatting" target="_blank" style="font-size: 16px; text-decoration: none;">nerlfield/wild-gaussian-splatting</a>
36
  </div>
 
45
  demo.load(start_session)
46
  demo.unload(end_session)
47
 
48
+ demo.launch(show_error=True, share=None, server_name='0.0.0.0', server_port=8111)
49
+ # demo.launch(show_error=True, share=None, server_name=None, server_port=None)
demo/demo_globals.py CHANGED
@@ -12,6 +12,8 @@ weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
12
 
13
  CACHE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
14
  os.makedirs(CACHE_PATH, exist_ok=True)
 
 
15
 
16
  DEVICE = device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
  MODEL = AsymmetricMASt3R.from_pretrained(weights_path).to(DEVICE)
 
12
 
13
  CACHE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
14
  os.makedirs(CACHE_PATH, exist_ok=True)
15
+ EXAMPLE_PATH = os.path.join(CACHE_PATH, 'examples_datasets')
16
+ os.makedirs(EXAMPLE_PATH, exist_ok=True)
17
 
18
  DEVICE = device = 'cuda' if torch.cuda.is_available() else 'cpu'
19
  MODEL = AsymmetricMASt3R.from_pretrained(weights_path).to(DEVICE)
demo/gs_demo.py CHANGED
@@ -2,14 +2,16 @@ import gradio as gr
2
  from gs_train import train
3
  import os
4
 
5
- from demo_globals import CACHE_PATH, MODEL, DEVICE, SILENT, DATASET_DIR
6
 
7
  def get_dataset_folders(datasets_path):
8
- try:
9
- folders = [f for f in os.listdir(datasets_path) if os.path.isdir(os.path.join(datasets_path, f))]
10
- return sorted(folders, key=lambda x: int(x.split('_')[-1]) if x.split('_')[-1].isdigit() else float('inf'))
11
- except FileNotFoundError:
12
- return []
 
 
13
 
14
  def gs_demo_tab():
15
  # datasets_path = "/app/data/scenes/"
@@ -112,11 +114,12 @@ def gs_demo_tab():
112
 
113
  def handle_training_complete(selected_folder, req: gr.Request, *args):
114
  USER_DIR = os.path.join(CACHE_PATH, str(req.session_hash))
115
- dataset_path = os.path.join(USER_DIR, DATASET_DIR)
116
- # Construct the full path to the selected dataset
117
- selected_data_path = os.path.join(dataset_path, selected_folder)
 
118
  # Call the training function with the full path
119
- video_path, model_path = train(selected_data_path, *args)
120
  # Then return all required outputs
121
  return [
122
  video_path, # video output
 
2
  from gs_train import train
3
  import os
4
 
5
+ from demo_globals import CACHE_PATH, EXAMPLE_PATH, MODEL, DEVICE, SILENT, DATASET_DIR
6
 
7
  def get_dataset_folders(datasets_path):
8
+
9
+ folder = []
10
+ if os.path.isdir(datasets_path):
11
+ folder += [f for f in os.listdir(datasets_path) if os.path.isdir(os.path.join(datasets_path, f))]
12
+ if os.path.isdir(EXAMPLE_PATH):
13
+ folder += [f for f in os.listdir(EXAMPLE_PATH) if os.path.isdir(os.path.join(EXAMPLE_PATH, f))]
14
+ return sorted(folder, key=lambda x: int(x.split('_')[-1]) if x.split('_')[-1].isdigit() else float('inf'))
15
 
16
  def gs_demo_tab():
17
  # datasets_path = "/app/data/scenes/"
 
114
 
115
  def handle_training_complete(selected_folder, req: gr.Request, *args):
116
  USER_DIR = os.path.join(CACHE_PATH, str(req.session_hash))
117
+ if 'run' in selected_folder:
118
+ dataset_path = os.path.join(USER_DIR, DATASET_DIR, selected_folder)
119
+ else:
120
+ dataset_path = os.path.join(EXAMPLE_PATH, selected_folder)
121
  # Call the training function with the full path
122
+ video_path, model_path = train(dataset_path, *args)
123
  # Then return all required outputs
124
  return [
125
  video_path, # video output
demo/mast3r_demo.py CHANGED
@@ -18,6 +18,7 @@ import copy
18
  from scipy.spatial.transform import Rotation
19
  import tempfile
20
  import shutil
 
21
 
22
  from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
23
  from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
@@ -36,7 +37,7 @@ import torch
36
  import os.path as path
37
  HERE_PATH = path.normpath(path.dirname(__file__)) # noqa
38
 
39
- from demo_globals import CACHE_PATH, MODEL, DEVICE, SILENT, DATASET_DIR
40
 
41
  class SparseGAState():
42
  def __init__(self, cache_dir=None, outfile_name=None):
@@ -112,16 +113,11 @@ def _convert_scene_output_to_glb(outfile, imgs, pts3d, mask, focals, cams2world,
112
  return outfile
113
 
114
 
115
- def get_3D_model_from_scene(scene, scene_state, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
116
  clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
117
  """
118
  extract 3D_model (glb file) from a reconstructed scene
119
  """
120
- if scene_state is None:
121
- return None
122
- outfile = scene_state.outfile_name
123
- if outfile is None:
124
- return None
125
 
126
  # # get optimized values from scene
127
  # scene = scenescene_state.sparse_ga
@@ -174,14 +170,20 @@ def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False, mask_i
174
  return save_path
175
 
176
  @spaces.GPU(duration=160)
177
- def get_reconstructed_scene(snapshot, current_scene_state,
178
  min_conf_thr, matching_conf_thr,
179
- as_pointcloud, cam_size, shared_intrinsics, clean_depth, filelist, req: gradio.Request, **kw):
180
  """
181
  from a list of images, run mast3r inference, sparse global aligner.
182
  then run get_3D_model_from_scene
183
  """
184
- USER_DIR = os.path.join(CACHE_PATH, str(req.session_hash))
 
 
 
 
 
 
185
  image_size = 512
186
  imgs = load_images(filelist, size=image_size, verbose=not SILENT)
187
  if len(imgs) == 1:
@@ -236,26 +238,25 @@ def get_reconstructed_scene(snapshot, current_scene_state,
236
  opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
237
  matching_conf_thr=matching_conf_thr, **kw)
238
 
239
- base_colmapdata_dir = os.path.join(USER_DIR, DATASET_DIR)
240
- os.makedirs(base_colmapdata_dir, exist_ok=True)
241
- colmap_data_dir = get_next_dir(base_colmapdata_dir)
242
- #
243
- save_colmap_scene(scene, colmap_data_dir, min_conf_thr, clean_depth)
244
 
245
- if current_scene_state is not None and \
246
- current_scene_state.outfile_name is not None:
247
- outfile_name = current_scene_state.outfile_name
248
  else:
249
- outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=USER_DIR)
 
 
 
 
250
 
251
- scene_state = SparseGAState(cache_dir, outfile_name)
252
- outfile = get_3D_model_from_scene(scene, scene_state, min_conf_thr, as_pointcloud, mask_sky,
 
253
  clean_depth, transparent_cams, cam_size, TSDF_thresh)
254
  print(f"colmap_data_dir: {colmap_data_dir}")
255
  print(f"outfile_name: {outfile_name}")
256
  print(f"cache_dir: {cache_dir}")
257
  torch.cuda.empty_cache()
258
- return scene_state, outfile
259
 
260
 
261
  def mast3r_demo_tab():
@@ -285,6 +286,12 @@ def mast3r_demo_tab():
285
  snapshot = gradio.Image(None, visible=False)
286
  run_btn = gradio.Button("Run")
287
 
 
 
 
 
 
 
288
  with gradio.Row():
289
  matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=2.,
290
  minimum=0., maximum=30., step=0.1,
@@ -300,9 +307,9 @@ def mast3r_demo_tab():
300
  outmodel = gradio.Model3D()
301
  run_btn.click(
302
  fn=get_reconstructed_scene,
303
- inputs=[snapshot, scene, min_conf_thr, matching_conf_thr,
304
- as_pointcloud, cam_size, shared_intrinsics, clean_depth, inputfiles],
305
- outputs=[scene, outmodel]
306
  )
307
 
308
  tower_folder = os.path.join(HERE_PATH, '../wild-gaussian-splatting/mast3r/assets/NLE_tower/')
@@ -315,35 +322,49 @@ def mast3r_demo_tab():
315
  examples = gradio.Examples(
316
  examples=[
317
  [
318
- turtle_images[0],
319
- None,
320
  1.5, 0.0, 0.2, True, True, False,
321
- turtle_images,
 
 
322
  ]
323
  ],
324
- inputs=[snapshot, scene, min_conf_thr, matching_conf_thr, cam_size, as_pointcloud, shared_intrinsics, clean_depth, inputfiles],
 
 
 
 
325
  )
326
  examples = gradio.Examples(
327
  examples=[
328
  [
329
- puma_images[0],
330
- None,
331
  1.5, 0.0, 0.2, True, True, False,
332
- puma_images,
 
 
333
  ]
334
  ],
335
- inputs=[snapshot, scene, min_conf_thr, matching_conf_thr, cam_size, as_pointcloud, shared_intrinsics, clean_depth, inputfiles],
 
 
 
 
336
  )
337
  examples = gradio.Examples(
338
  examples=[
339
  [
340
  tower_images[0],
341
- None,
342
  1.5, 0.0, 0.2, True, False, False,
343
  tower_images,
 
344
  ]
345
  ],
346
- inputs=[snapshot, scene, min_conf_thr, matching_conf_thr, cam_size, as_pointcloud, shared_intrinsics, clean_depth, inputfiles],
 
 
 
 
347
  )
348
 
349
 
 
18
  from scipy.spatial.transform import Rotation
19
  import tempfile
20
  import shutil
21
+ import typing
22
 
23
  from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
24
  from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
 
37
  import os.path as path
38
  HERE_PATH = path.normpath(path.dirname(__file__)) # noqa
39
 
40
+ from demo_globals import CACHE_PATH, EXAMPLE_PATH, MODEL, DEVICE, SILENT, DATASET_DIR
41
 
42
  class SparseGAState():
43
  def __init__(self, cache_dir=None, outfile_name=None):
 
113
  return outfile
114
 
115
 
116
+ def get_3D_model_from_scene(scene, outfile, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
117
  clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
118
  """
119
  extract 3D_model (glb file) from a reconstructed scene
120
  """
 
 
 
 
 
121
 
122
  # # get optimized values from scene
123
  # scene = scenescene_state.sparse_ga
 
170
  return save_path
171
 
172
  @spaces.GPU(duration=160)
173
+ def get_reconstructed_scene(snapshot,
174
  min_conf_thr, matching_conf_thr,
175
+ as_pointcloud, cam_size, shared_intrinsics, clean_depth, filelist, example_name, req: gradio.Request, **kw):
176
  """
177
  from a list of images, run mast3r inference, sparse global aligner.
178
  then run get_3D_model_from_scene
179
  """
180
+
181
+ if example_name != '':
182
+ USER_DIR = os.path.join(CACHE_PATH, example_name)
183
+ else:
184
+ USER_DIR = os.path.join(CACHE_PATH, str(req.session_hash))
185
+ os.makedirs(USER_DIR, exist_ok=True)
186
+
187
  image_size = 512
188
  imgs = load_images(filelist, size=image_size, verbose=not SILENT)
189
  if len(imgs) == 1:
 
238
  opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
239
  matching_conf_thr=matching_conf_thr, **kw)
240
 
 
 
 
 
 
241
 
242
+ if example_name:
243
+ colmap_data_dir = os.path.join(EXAMPLE_PATH, example_name)
 
244
  else:
245
+ colmap_data_dir = get_next_dir(os.path.join(USER_DIR, DATASET_DIR))
246
+ os.makedirs(colmap_data_dir, exist_ok=True)
247
+
248
+
249
+ save_colmap_scene(scene, colmap_data_dir, min_conf_thr, clean_depth)
250
 
251
+ outfile_name = os.path.join(USER_DIR, 'default_scene.glb')
252
+
253
+ outfile = get_3D_model_from_scene(scene, outfile_name, min_conf_thr, as_pointcloud, mask_sky,
254
  clean_depth, transparent_cams, cam_size, TSDF_thresh)
255
  print(f"colmap_data_dir: {colmap_data_dir}")
256
  print(f"outfile_name: {outfile_name}")
257
  print(f"cache_dir: {cache_dir}")
258
  torch.cuda.empty_cache()
259
+ return outfile
260
 
261
 
262
  def mast3r_demo_tab():
 
286
  snapshot = gradio.Image(None, visible=False)
287
  run_btn = gradio.Button("Run")
288
 
289
+ dummy_req = gradio.Request()
290
+ dummy_text = gradio.Textbox(value="", visible=False)
291
+
292
+ example_name = gradio.Textbox(value="", visible=False)
293
+
294
+
295
  with gradio.Row():
296
  matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=2.,
297
  minimum=0., maximum=30., step=0.1,
 
307
  outmodel = gradio.Model3D()
308
  run_btn.click(
309
  fn=get_reconstructed_scene,
310
+ inputs=[snapshot, min_conf_thr, matching_conf_thr,
311
+ as_pointcloud, cam_size, shared_intrinsics, clean_depth, inputfiles, dummy_text],
312
+ outputs=[outmodel]
313
  )
314
 
315
  tower_folder = os.path.join(HERE_PATH, '../wild-gaussian-splatting/mast3r/assets/NLE_tower/')
 
322
  examples = gradio.Examples(
323
  examples=[
324
  [
325
+ puma_images[0],
 
326
  1.5, 0.0, 0.2, True, True, False,
327
+ puma_images,
328
+ 'puma',
329
+ None,
330
  ]
331
  ],
332
+ inputs=[snapshot, min_conf_thr, matching_conf_thr, cam_size, as_pointcloud, shared_intrinsics, clean_depth, inputfiles, example_name],
333
+ fn=get_reconstructed_scene,
334
+ outputs=[outmodel],
335
+ run_on_click=True,
336
+ cache_examples='lazy',
337
  )
338
  examples = gradio.Examples(
339
  examples=[
340
  [
341
+ turtle_images[0],
 
342
  1.5, 0.0, 0.2, True, True, False,
343
+ turtle_images,
344
+ 'turtle',
345
+ None
346
  ]
347
  ],
348
+ inputs=[snapshot, min_conf_thr, matching_conf_thr, cam_size, as_pointcloud, shared_intrinsics, clean_depth, inputfiles, example_name],
349
+ fn=get_reconstructed_scene,
350
+ outputs=[outmodel],
351
+ run_on_click=True,
352
+ cache_examples='lazy',
353
  )
354
  examples = gradio.Examples(
355
  examples=[
356
  [
357
  tower_images[0],
 
358
  1.5, 0.0, 0.2, True, False, False,
359
  tower_images,
360
+ 'tower',
361
  ]
362
  ],
363
+ inputs=[snapshot, min_conf_thr, matching_conf_thr, cam_size, as_pointcloud, shared_intrinsics, clean_depth, inputfiles, example_name],
364
+ fn=get_reconstructed_scene,
365
+ outputs=[outmodel],
366
+ run_on_click=True,
367
+ cache_examples='lazy',
368
  )
369
 
370