ostapagon commited on
Commit
665b2f0
·
1 Parent(s): d38e5ca

Add more fixes to demo files

Browse files
Files changed (4) hide show
  1. app.py +2 -0
  2. demo/gs_demo.py +2 -2
  3. demo/gs_train.py +3 -3
  4. demo/mast3r_demo.py +36 -36
app.py CHANGED
@@ -20,6 +20,7 @@ if __name__ == '__main__':
20
  # else:
21
  server_name = '0.0.0.0'# if args.local_network else '127.0.0.1'
22
 
 
23
  weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"#args.weights if args.weights is not None else + MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric
24
  device = device = 'cuda' if torch.cuda.is_available() else 'cpu'
25
  chkpt_tag = hash_md5(weights_path)
@@ -36,5 +37,6 @@ if __name__ == '__main__':
36
  gs_demo_tab(cache_path)
37
 
38
  demo.launch(show_error=True, share=None, server_name=None, server_port=None)
 
39
 
40
  # python3 demo.py --weights "/app/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" --device "cuda" --server_port 3334 --local_network "$@"
 
20
  # else:
21
  server_name = '0.0.0.0'# if args.local_network else '127.0.0.1'
22
 
23
+ # weights_path = '/app/wild-gaussian-splatting/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
24
  weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"#args.weights if args.weights is not None else + MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric
25
  device = device = 'cuda' if torch.cuda.is_available() else 'cpu'
26
  chkpt_tag = hash_md5(weights_path)
 
37
  gs_demo_tab(cache_path)
38
 
39
  demo.launch(show_error=True, share=None, server_name=None, server_port=None)
40
+ # demo.launch(show_error=True, share=None, server_name='0.0.0.0', server_port=5555)
41
 
42
  # python3 demo.py --weights "/app/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" --device "cuda" --server_port 3334 --local_network "$@"
demo/gs_demo.py CHANGED
@@ -38,7 +38,7 @@ def gs_demo_tab(cache_path):
38
  def update_dataset_dropdown():
39
  print("update_dataset_dropdown, cache_path", cache_path)
40
  # Update the dataset folders list
41
- dataset_folders = get_dataset_folders(datasets_path)
42
  # dataset_folders = "/app/data/scenes/"
43
  print("dataset_folders", dataset_folders)
44
  # Only set a default value if there are folders available
@@ -109,7 +109,7 @@ def gs_demo_tab(cache_path):
109
 
110
  def handle_training_complete(selected_folder, *args):
111
  # Construct the full path to the selected dataset
112
- selected_data_path = os.path.join(datasets_path, selected_folder)
113
  # Call the training function with the full path
114
  video_path, model_path = train(selected_data_path, *args)
115
  # Then return all required outputs
 
38
  def update_dataset_dropdown():
39
  print("update_dataset_dropdown, cache_path", cache_path)
40
  # Update the dataset folders list
41
+ dataset_folders = get_dataset_folders(dataset_path)
42
  # dataset_folders = "/app/data/scenes/"
43
  print("dataset_folders", dataset_folders)
44
  # Only set a default value if there are folders available
 
109
 
110
  def handle_training_complete(selected_folder, *args):
111
  # Construct the full path to the selected dataset
112
+ selected_data_path = os.path.join(dataset_path, selected_folder)
113
  # Call the training function with the full path
114
  video_path, model_path = train(selected_data_path, *args)
115
  # Then return all required outputs
demo/gs_train.py CHANGED
@@ -8,7 +8,7 @@ import gradio as gr
8
  import importlib.util
9
  from dataclasses import dataclass, field
10
 
11
- import spaces
12
 
13
 
14
  @dataclass
@@ -60,7 +60,7 @@ class TrainingArgs:
60
  checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
61
  start_checkpoint: str = None
62
 
63
- @spaces.GPU(duration=20)
64
  def train(
65
  data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
66
  convert_SHs_python, compute_cov3D_python, debug,
@@ -77,7 +77,7 @@ def train(
77
 
78
  # Import necessary modules from the gaussian-splatting directory
79
  from utils.loss_utils import l1_loss, ssim
80
- from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
81
  from scene import Scene, GaussianModel
82
  from utils.general_utils import safe_state
83
  from utils.image_utils import psnr
 
8
  import importlib.util
9
  from dataclasses import dataclass, field
10
 
11
+ # import spaces
12
 
13
 
14
  @dataclass
 
60
  checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
61
  start_checkpoint: str = None
62
 
63
+ # @spaces.GPU(duration=20)
64
  def train(
65
  data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
66
  convert_SHs_python, compute_cov3D_python, debug,
 
77
 
78
  # Import necessary modules from the gaussian-splatting directory
79
  from utils.loss_utils import l1_loss, ssim
80
+ from gaussian_renderer import render
81
  from scene import Scene, GaussianModel
82
  from utils.general_utils import safe_state
83
  from utils.image_utils import psnr
demo/mast3r_demo.py CHANGED
@@ -139,42 +139,42 @@ def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=F
139
  return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
140
  transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
141
 
142
- def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False):
143
- if 'save_pointcloud_with_normals' not in globals():
144
- sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/gaussian-splatting'))
145
- sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/src'))
146
- from colmap_dataset_utils import (
147
- inv,
148
- init_filestructure,
149
- save_images_masks,
150
- save_cameras,
151
- save_imagestxt,
152
- save_pointcloud,
153
- save_pointcloud_with_normals
154
- )
155
-
156
- cam2world = scene.get_im_poses().detach().cpu().numpy()
157
- world2cam = inv(cam2world) #
158
- principal_points = scene.get_principal_points().detach().cpu().numpy()
159
- focals = scene.get_focals().detach().cpu().numpy()[..., None]
160
- imgs = np.array(scene.imgs)
161
-
162
- pts3d, _, confs = scene.get_dense_pts3d(clean_depth=clean_depth)
163
- pts3d = [i.detach().reshape(imgs[0].shape) for i in pts3d] #
164
-
165
- masks = to_numpy([c > min_conf_thr for c in to_numpy(confs)])
166
-
167
- # move
168
- mask_images = True
169
-
170
- save_path, images_path, masks_path, sparse_path = init_filestructure(save_dir)
171
- save_images_masks(imgs, masks, images_path, masks_path, mask_images)
172
- save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
173
- save_imagestxt(world2cam, sparse_path)
174
- save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)
175
- return save_path
176
-
177
- @spaces.GPU(duration=20)
178
  def get_reconstructed_scene(outdir, model, device, silent, image_size, current_scene_state,
179
  filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
180
  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
 
139
  return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
140
  transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
141
 
142
+ # def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False):
143
+ # if 'save_pointcloud_with_normals' not in globals():
144
+ # sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/gaussian-splatting'))
145
+ # sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/src'))
146
+ # from colmap_dataset_utils import (
147
+ # inv,
148
+ # init_filestructure,
149
+ # save_images_masks,
150
+ # save_cameras,
151
+ # save_imagestxt,
152
+ # save_pointcloud,
153
+ # save_pointcloud_with_normals
154
+ # )
155
+
156
+ # cam2world = scene.get_im_poses().detach().cpu().numpy()
157
+ # world2cam = inv(cam2world) #
158
+ # principal_points = scene.get_principal_points().detach().cpu().numpy()
159
+ # focals = scene.get_focals().detach().cpu().numpy()[..., None]
160
+ # imgs = np.array(scene.imgs)
161
+
162
+ # pts3d, _, confs = scene.get_dense_pts3d(clean_depth=clean_depth)
163
+ # pts3d = [i.detach().reshape(imgs[0].shape) for i in pts3d] #
164
+
165
+ # masks = to_numpy([c > min_conf_thr for c in to_numpy(confs)])
166
+
167
+ # # move
168
+ # mask_images = True
169
+
170
+ # save_path, images_path, masks_path, sparse_path = init_filestructure(save_dir)
171
+ # save_images_masks(imgs, masks, images_path, masks_path, mask_images)
172
+ # save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
173
+ # save_imagestxt(world2cam, sparse_path)
174
+ # save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)
175
+ # return save_path
176
+
177
+ @spaces.GPU(duration=10)
178
  def get_reconstructed_scene(outdir, model, device, silent, image_size, current_scene_state,
179
  filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
180
  as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,