ostapagon committed
Commit 1308f6d · 1 Parent(s): f2a83d8

Add globals to avoid using inner declarations

Files changed (3)
  1. app.py +15 -21
  2. demo/demo_globals.py +19 -0
  3. demo/mast3r_demo.py +10 -9
app.py CHANGED
@@ -2,12 +2,9 @@ import sys
 sys.path.append('wild-gaussian-splatting/mast3r/')
 sys.path.append('demo/')
 
-import os
-import tempfile
 import gradio as gr
 import torch
 from mast3r.demo import get_args_parser
-from mast3r.utils.misc import hash_md5
 from mast3r_demo import mast3r_demo_tab
 # from gs_demo import gs_demo_tab
 
@@ -18,25 +15,22 @@ if __name__ == '__main__':
     # if args.server_name is not None:
     #     server_name = args.server_name
     # else:
-    server_name = '0.0.0.0'# if args.local_network else '127.0.0.1'
+    # server_name = '0.0.0.0'# if args.local_network else '127.0.0.1'
 
     # weights_path = '/app/wild-gaussian-splatting/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
-    weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"#args.weights if args.weights is not None else + MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric
-    device = device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    chkpt_tag = hash_md5(weights_path)
-
-    with tempfile.TemporaryDirectory(suffix='demo') as tmpdirname:
-        cache_path = os.path.join(tmpdirname, chkpt_tag)
-        os.makedirs(cache_path, exist_ok=True)
-
-        with gr.Blocks() as demo:
-            with gr.Tabs():
-                with gr.Tab("MASt3R Demo"):
-                    mast3r_demo_tab(cache_path, weights_path, device)
-                # with gr.Tab("Gaussian Splatting Demo"):
-                #     gs_demo_tab(cache_path)
-
-        demo.launch(show_error=True, share=None, server_name=None, server_port=None)
-        # demo.launch(show_error=True, share=None, server_name='0.0.0.0', server_port=5555)
+    # weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"#args.weights if args.weights is not None else + MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric
+    # device = device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    # chkpt_tag = hash_md5(weights_path)
+
+
+    with gr.Blocks() as demo:
+        with gr.Tabs():
+            with gr.Tab("MASt3R Demo"):
+                mast3r_demo_tab()
+            # with gr.Tab("Gaussian Splatting Demo"):
+            #     gs_demo_tab(cache_path)
+
+    demo.launch(show_error=True, share=None, server_name=None, server_port=None)
+    # demo.launch(show_error=True, share=None, server_name='0.0.0.0', server_port=5555)
 
 # python3 demo.py --weights "/app/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" --device "cuda" --server_port 3334 --local_network "$@"
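
Taken together, the hunks above leave app.py as roughly the entry point sketched below. The argument-parsing lines fall outside the diff context and are assumptions here, not part of what this commit shows:

# Approximate post-commit app.py, reconstructed from the hunks above.
# The get_args_parser()/parse_args() lines are assumptions (outside the diff context).
import sys
sys.path.append('wild-gaussian-splatting/mast3r/')
sys.path.append('demo/')

import gradio as gr
import torch
from mast3r.demo import get_args_parser
from mast3r_demo import mast3r_demo_tab  # importing this also imports demo/demo_globals.py

if __name__ == '__main__':
    parser = get_args_parser()   # assumed: not visible in the diff context
    args = parser.parse_args()   # assumed: not visible in the diff context

    with gr.Blocks() as demo:
        with gr.Tabs():
            with gr.Tab("MASt3R Demo"):
                mast3r_demo_tab()  # cache path, weights and device now live in demo_globals

    demo.launch(show_error=True, share=None, server_name=None, server_port=None)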
demo/demo_globals.py ADDED
@@ -0,0 +1,19 @@
+import sys
+sys.path.append('wild-gaussian-splatting/mast3r/')
+
+import os
+import tempfile
+import torch
+from mast3r.utils.misc import hash_md5
+from mast3r.model import AsymmetricMASt3R
+
+weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
+weights_path = '/app/wild-gaussian-splatting/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
+tmpdirname = tempfile.TemporaryDirectory(suffix='demo')
+chkpt_tag = hash_md5(weights_path)
+CACHE_PATH = os.path.join(tmpdirname.name, chkpt_tag)
+os.makedirs(CACHE_PATH, exist_ok=True)
+
+DEVICE = device = 'cuda' if torch.cuda.is_available() else 'cpu'
+MODEL = AsymmetricMASt3R.from_pretrained(weights_path).to(DEVICE)
+SILENT = False
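
Note on behaviour: app.py previously created the cache inside a with tempfile.TemporaryDirectory(...) block, so the directory was deleted as soon as that block exited. demo_globals.py now holds the TemporaryDirectory object at module scope (and loads the model as an import side effect), so the cache persists for the lifetime of the process. If explicit cleanup were ever wanted, a minimal sketch, not part of this commit, would be:

# Sketch only, not part of this commit: remove the global cache directory
# created in demo/demo_globals.py when the interpreter exits.
import atexit

from demo_globals import tmpdirname  # the module-level tempfile.TemporaryDirectory()

# TemporaryDirectory.cleanup() deletes the directory tree; registering it with
# atexit roughly restores the old `with`-block behaviour at process exit.
atexit.register(tmpdirname.cleanup)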
demo/mast3r_demo.py CHANGED
@@ -34,6 +34,8 @@ import matplotlib.pyplot as pl
 import torch
 
 
+from demo_globals import CACHE_PATH, MODEL, DEVICE, SILENT
+
 class SparseGAState():
     def __init__(self, sparse_ga, cache_dir=None, outfile_name=None):
         self.sparse_ga = sparse_ga
@@ -266,17 +268,16 @@ def set_scenegraph_options(inputfiles, win_cyclic, refid, scenegraph_type):
     return win_col, winsize, win_cyclic, refid
 
 
-def mast3r_demo_tab(cache_path, weights_path, device, silent=False):
-    model = AsymmetricMASt3R.from_pretrained(weights_path).to(device)
-
-    if not silent:
-        print('Outputing stuff in', cache_path)
+def get_reconstructed_scene_wrapper_func(*args, **kwargs):
+    return get_reconstructed_scene(CACHE_PATH, MODEL, DEVICE, SILENT, *args, **kwargs)
 
-    def get_reconstructed_scene_wrapper_func(*args, **kwargs):
-        return get_reconstructed_scene(cache_path, model, device, silent, *args, **kwargs)
+def update_3D_model_from_scene(*args, **kwargs):
+    return get_3D_model_from_scene(SILENT, *args, **kwargs)
 
-    def update_3D_model_from_scene(silent, *args, **kwargs):
-        return get_3D_model_from_scene(silent, *args, **kwargs)
+def mast3r_demo_tab():
+
+    if not SILENT:
+        print('Outputing stuff in', CACHE_PATH)
 
     def get_context():
         css = """.gradio-container {margin: 0 !important; min-width: 100%};"""