Spaces:
Running
on
Zero
Running
on
Zero
Add more fixes to demo files
Browse files- app.py +2 -0
- demo/gs_demo.py +2 -2
- demo/gs_train.py +3 -3
- demo/mast3r_demo.py +36 -36
app.py
CHANGED
@@ -20,6 +20,7 @@ if __name__ == '__main__':
|
|
20 |
# else:
|
21 |
server_name = '0.0.0.0'# if args.local_network else '127.0.0.1'
|
22 |
|
|
|
23 |
weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"#args.weights if args.weights is not None else + MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric
|
24 |
device = device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
25 |
chkpt_tag = hash_md5(weights_path)
|
@@ -36,5 +37,6 @@ if __name__ == '__main__':
|
|
36 |
gs_demo_tab(cache_path)
|
37 |
|
38 |
demo.launch(show_error=True, share=None, server_name=None, server_port=None)
|
|
|
39 |
|
40 |
# python3 demo.py --weights "/app/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" --device "cuda" --server_port 3334 --local_network "$@"
|
|
|
20 |
# else:
|
21 |
server_name = '0.0.0.0'# if args.local_network else '127.0.0.1'
|
22 |
|
23 |
+
# weights_path = '/app/wild-gaussian-splatting/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
|
24 |
weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"#args.weights if args.weights is not None else + MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric
|
25 |
device = device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
26 |
chkpt_tag = hash_md5(weights_path)
|
|
|
37 |
gs_demo_tab(cache_path)
|
38 |
|
39 |
demo.launch(show_error=True, share=None, server_name=None, server_port=None)
|
40 |
+
# demo.launch(show_error=True, share=None, server_name='0.0.0.0', server_port=5555)
|
41 |
|
42 |
# python3 demo.py --weights "/app/mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth" --device "cuda" --server_port 3334 --local_network "$@"
|
demo/gs_demo.py
CHANGED
@@ -38,7 +38,7 @@ def gs_demo_tab(cache_path):
|
|
38 |
def update_dataset_dropdown():
|
39 |
print("update_dataset_dropdown, cache_path", cache_path)
|
40 |
# Update the dataset folders list
|
41 |
-
dataset_folders = get_dataset_folders(
|
42 |
# dataset_folders = "/app/data/scenes/"
|
43 |
print("dataset_folders", dataset_folders)
|
44 |
# Only set a default value if there are folders available
|
@@ -109,7 +109,7 @@ def gs_demo_tab(cache_path):
|
|
109 |
|
110 |
def handle_training_complete(selected_folder, *args):
|
111 |
# Construct the full path to the selected dataset
|
112 |
-
selected_data_path = os.path.join(
|
113 |
# Call the training function with the full path
|
114 |
video_path, model_path = train(selected_data_path, *args)
|
115 |
# Then return all required outputs
|
|
|
38 |
def update_dataset_dropdown():
|
39 |
print("update_dataset_dropdown, cache_path", cache_path)
|
40 |
# Update the dataset folders list
|
41 |
+
dataset_folders = get_dataset_folders(dataset_path)
|
42 |
# dataset_folders = "/app/data/scenes/"
|
43 |
print("dataset_folders", dataset_folders)
|
44 |
# Only set a default value if there are folders available
|
|
|
109 |
|
110 |
def handle_training_complete(selected_folder, *args):
|
111 |
# Construct the full path to the selected dataset
|
112 |
+
selected_data_path = os.path.join(dataset_path, selected_folder)
|
113 |
# Call the training function with the full path
|
114 |
video_path, model_path = train(selected_data_path, *args)
|
115 |
# Then return all required outputs
|
demo/gs_train.py
CHANGED
@@ -8,7 +8,7 @@ import gradio as gr
|
|
8 |
import importlib.util
|
9 |
from dataclasses import dataclass, field
|
10 |
|
11 |
-
import spaces
|
12 |
|
13 |
|
14 |
@dataclass
|
@@ -60,7 +60,7 @@ class TrainingArgs:
|
|
60 |
checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
|
61 |
start_checkpoint: str = None
|
62 |
|
63 |
-
@spaces.GPU(duration=20)
|
64 |
def train(
|
65 |
data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
|
66 |
convert_SHs_python, compute_cov3D_python, debug,
|
@@ -77,7 +77,7 @@ def train(
|
|
77 |
|
78 |
# Import necessary modules from the gaussian-splatting directory
|
79 |
from utils.loss_utils import l1_loss, ssim
|
80 |
-
from
|
81 |
from scene import Scene, GaussianModel
|
82 |
from utils.general_utils import safe_state
|
83 |
from utils.image_utils import psnr
|
|
|
8 |
import importlib.util
|
9 |
from dataclasses import dataclass, field
|
10 |
|
11 |
+
# import spaces
|
12 |
|
13 |
|
14 |
@dataclass
|
|
|
60 |
checkpoint_iterations: list[int] = field(default_factory=lambda: [7_000, 15_000, 30_000])
|
61 |
start_checkpoint: str = None
|
62 |
|
63 |
+
# @spaces.GPU(duration=20)
|
64 |
def train(
|
65 |
data_source_path, sh_degree, model_path, images, resolution, white_background, data_device, eval,
|
66 |
convert_SHs_python, compute_cov3D_python, debug,
|
|
|
77 |
|
78 |
# Import necessary modules from the gaussian-splatting directory
|
79 |
from utils.loss_utils import l1_loss, ssim
|
80 |
+
from gaussian_renderer import render
|
81 |
from scene import Scene, GaussianModel
|
82 |
from utils.general_utils import safe_state
|
83 |
from utils.image_utils import psnr
|
demo/mast3r_demo.py
CHANGED
@@ -139,42 +139,42 @@ def get_3D_model_from_scene(silent, scene_state, min_conf_thr=2, as_pointcloud=F
|
|
139 |
return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
|
140 |
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
|
141 |
|
142 |
-
def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False):
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
@spaces.GPU(duration=
|
178 |
def get_reconstructed_scene(outdir, model, device, silent, image_size, current_scene_state,
|
179 |
filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
|
180 |
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
|
|
|
139 |
return _convert_scene_output_to_glb(outfile, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
|
140 |
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
|
141 |
|
142 |
+
# def save_colmap_scene(scene, save_dir, min_conf_thr=2, clean_depth=False):
|
143 |
+
# if 'save_pointcloud_with_normals' not in globals():
|
144 |
+
# sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/gaussian-splatting'))
|
145 |
+
# sys.path.append(os.path.join(os.path.dirname(__file__), '../wild-gaussian-splatting/src'))
|
146 |
+
# from colmap_dataset_utils import (
|
147 |
+
# inv,
|
148 |
+
# init_filestructure,
|
149 |
+
# save_images_masks,
|
150 |
+
# save_cameras,
|
151 |
+
# save_imagestxt,
|
152 |
+
# save_pointcloud,
|
153 |
+
# save_pointcloud_with_normals
|
154 |
+
# )
|
155 |
+
|
156 |
+
# cam2world = scene.get_im_poses().detach().cpu().numpy()
|
157 |
+
# world2cam = inv(cam2world) #
|
158 |
+
# principal_points = scene.get_principal_points().detach().cpu().numpy()
|
159 |
+
# focals = scene.get_focals().detach().cpu().numpy()[..., None]
|
160 |
+
# imgs = np.array(scene.imgs)
|
161 |
+
|
162 |
+
# pts3d, _, confs = scene.get_dense_pts3d(clean_depth=clean_depth)
|
163 |
+
# pts3d = [i.detach().reshape(imgs[0].shape) for i in pts3d] #
|
164 |
+
|
165 |
+
# masks = to_numpy([c > min_conf_thr for c in to_numpy(confs)])
|
166 |
+
|
167 |
+
# # move
|
168 |
+
# mask_images = True
|
169 |
+
|
170 |
+
# save_path, images_path, masks_path, sparse_path = init_filestructure(save_dir)
|
171 |
+
# save_images_masks(imgs, masks, images_path, masks_path, mask_images)
|
172 |
+
# save_cameras(focals, principal_points, sparse_path, imgs_shape=imgs.shape)
|
173 |
+
# save_imagestxt(world2cam, sparse_path)
|
174 |
+
# save_pointcloud_with_normals(imgs, pts3d, masks, sparse_path)
|
175 |
+
# return save_path
|
176 |
+
|
177 |
+
@spaces.GPU(duration=10)
|
178 |
def get_reconstructed_scene(outdir, model, device, silent, image_size, current_scene_state,
|
179 |
filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
|
180 |
as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
|