JianyuanWang committed on
Commit e8227e4 · 1 Parent(s): e404aa3

update readme

Files changed (4):
  1. app.py +13 -12
  2. demo_gradio.py +23 -19
  3. demo_viser.py +15 -101
  4. gradio_util.py → visual_util.py +7 -4
app.py CHANGED
@@ -15,25 +15,27 @@ from datetime import datetime
  import glob
  import gc
  import time
- import spaces
+ # import spaces
  
  
  sys.path.append("vggt/")
  
- from gradio_util import predictions_to_glb
+ from visual_util import predictions_to_glb
  from vggt.models.vggt import VGGT
  from vggt.utils.load_fn import load_and_preprocess_images
  from vggt.utils.pose_enc import pose_encoding_to_extri_intri
  from vggt.utils.geometry import unproject_depth_map_to_point_map
  
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
  
  print("Initializing and loading VGGT model...")
  # model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model
  
- # device = "cuda" if torch.cuda.is_available() else "cpu"
  model = VGGT()
  _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
  model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
+ 
+ 
  model.eval()
  # model = model.to(device)
  
@@ -41,7 +43,7 @@ model.eval()
  # -------------------------------------------------------------------------
  # 1) Core model inference
  # -------------------------------------------------------------------------
- @spaces.GPU(duration=120)
+ # @spaces.GPU(duration=120)
  def run_model(target_dir, model) -> dict:
  """
  Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
@@ -181,7 +183,7 @@ def update_gallery_on_upload(input_video, input_images):
  # -------------------------------------------------------------------------
  # 4) Reconstruction: uses the target_dir plus any viz parameters
  # -------------------------------------------------------------------------
- @spaces.GPU(duration=120)
+ # @spaces.GPU(duration=120)
  def gradio_demo(
  target_dir,
  conf_thres=3.0,
@@ -313,7 +315,7 @@ def update_visualization(
  # Example images
  # -------------------------------------------------------------------------
  
- canyon_video = "examples/videos/Studlagil_Canyon_East_Iceland.mp4"
+ # canyon_video = "examples/videos/Studlagil_Canyon_East_Iceland.mp4"
  great_wall_video = "examples/videos/great_wall.mp4"
  colosseum_video = "examples/videos/Colosseum.mp4"
  room_video = "examples/videos/room.mp4"
@@ -392,9 +394,9 @@ with gr.Blocks(
  
  <h3>Getting Started:</h3>
  <ol>
- <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
+ <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
  <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
- <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
+ <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
  <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.</li>
  <li>
  <strong>Adjust Visualization (Optional):</strong>
@@ -406,17 +408,16 @@ with gr.Blocks(
  <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
  <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
  <li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
- <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
+ <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
  </ul>
  </details>
  </li>
  </ol>
- <p><strong>Please note:</strong> Our method usually only needs less than 1 second to reconstruct a scene, but the visualization of 3D points may take tens of seconds, especially when the number of images is large. Please be patient or, for faster visualization, use a local machine to run our demo from our <a href="https://github.com/facebookresearch/vggt">GitHub repository</a>.</p>
+ <p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">Our method usually only needs less than 1 second to reconstruct a scene, but the visualization of 3D points may take tens of seconds</span>, especially when the number of images is large. Please be patient or, for faster visualization, use a local machine to run our demo from our <a href="https://github.com/facebookresearch/vggt">GitHub repository</a>.</p>
  </div>
  """
  )
  
- 
  target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
  
  with gr.Row():
@@ -472,7 +473,7 @@ with gr.Blocks(
  [pyramid_video, "30", None, 35.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
  [single_cartoon_video, "1", None, 15.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
  [single_oil_painting_video, "1", None, 20.0, False, True, True, True, "Depthmap and Camera Branch", "True"],
- [canyon_video, "14", None, 40.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
+ # [canyon_video, "14", None, 40.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
  [room_video, "8", None, 5.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
  [kitchen_video, "25", None, 50.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
  [fern_video, "20", None, 45.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
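The app.py hunks above swap `gradio_util` for `visual_util` and comment out the `spaces` import together with the `@spaces.GPU(duration=120)` decorators, so the script no longer assumes the Hugging Face Spaces ZeroGPU runtime. A minimal sketch of keeping the decorator optional is shown below; the try/except fallback is an illustration only and is not part of this commit.

# Hypothetical pattern, not in the commit: use spaces.GPU when available, a no-op otherwise.
try:
    import spaces
    gpu_decorator = spaces.GPU(duration=120)
except ImportError:
    def gpu_decorator(fn):
        return fn

@gpu_decorator
def run_model(target_dir, model) -> dict:
    ...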
demo_gradio.py CHANGED
@@ -18,20 +18,22 @@ import time
  
  sys.path.append("vggt/")
  
- from gradio_util import predictions_to_glb
+ from visual_util import predictions_to_glb
  from vggt.models.vggt import VGGT
  from vggt.utils.load_fn import load_and_preprocess_images
  from vggt.utils.pose_enc import pose_encoding_to_extri_intri
  from vggt.utils.geometry import unproject_depth_map_to_point_map
  
+ device = "cuda" if torch.cuda.is_available() else "cpu"
  
  print("Initializing and loading VGGT model...")
  # model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model
  
- device = "cuda" if torch.cuda.is_available() else "cpu"
  model = VGGT()
  _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
  model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
+ 
+ 
  model.eval()
  model = model.to(device)
  
@@ -375,35 +377,37 @@ with gr.Blocks(
  is_example = gr.Textbox(label="is_example", visible=False, value="None")
  num_images = gr.Textbox(label="num_images", visible=False, value="None")
  
- gr.Markdown(
- """
- # 🏛️ VGGT: Visual Geometry Grounded Transformer
- 
- [🐙 GitHub Repository](https://github.com/facebookresearch/vggt) | [Project Page]()
+ gr.HTML(
+ """
+ <h1>🏛️ VGGT: Visual Geometry Grounded Transformer</h1>
+ <p>
+ <a href="https://github.com/facebookresearch/vggt">🐙 GitHub Repository</a> |
+ <a href="#">Project Page</a>
+ </p>
  
  <div style="font-size: 16px; line-height: 1.5;">
- <p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
+ <p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
  
  <h3>Getting Started:</h3>
  <ol>
- <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
- <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
- <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
- <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for large number of input images. </li>
- <li>
+ <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
+ <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
+ <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
+ <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.</li>
+ <li>
  <strong>Adjust Visualization (Optional):</strong>
- After reconstruction, you can fine-tune the visualization using the options below
+ After reconstruction, you can fine-tune the visualization using the options below
  <details style="display:inline;">
- <summary style="display:inline;">(<strong>click to expand</strong>):</summary>
- <ul>
+ <summary style="display:inline;">(<strong>click to expand</strong>):</summary>
+ <ul>
  <li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
  <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
  <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
  <li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
- <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
- </ul>
+ <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
+ </ul>
  </details>
- </li>
+ </li>
  </ol>
  <p><strong>Please note:</strong> Our method usually only needs less than 1 second to reconstruct a scene, but the visualization of 3D points may take tens of seconds, especially when the number of images is large. Please be patient or, for faster visualization, use a local machine to run our demo from our <a href="https://github.com/facebookresearch/vggt">GitHub repository</a>.</p>
  </div>
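demo_gradio.py now pins device selection before model creation and keeps the `VGGT.from_pretrained("facebook/VGGT-1B")` route only as a comment, loading weights from the released `model.pt` instead. Collected into one runnable snippet, the loading path after this commit looks roughly like this (nothing here beyond what the hunks show):

import torch
from vggt.models.vggt import VGGT

device = "cuda" if torch.cuda.is_available() else "cpu"

model = VGGT()
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
model.eval()
model = model.to(device)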
demo_viser.py CHANGED
@@ -10,7 +10,6 @@ import time
  import threading
  import argparse
  from typing import List, Optional
- import copy
  
  import numpy as np
  import torch
@@ -18,12 +17,14 @@ from tqdm.auto import tqdm
  import viser
  import viser.transforms as viser_tf
  import cv2
- import requests
+ 
+ 
  try:
  import onnxruntime
  except ImportError:
  print("onnxruntime not found. Sky segmentation may not work.")
  
+ from visual_util import segment_sky, download_file_from_url
  from vggt.models.vggt import VGGT
  from vggt.utils.load_fn import load_and_preprocess_images
  from vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
@@ -95,7 +96,7 @@ def viser_wrapper(
  # Flatten
  points = world_points.reshape(-1, 3)
  colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
- conf = conf.reshape(-1)
+ conf_flat = conf.reshape(-1)
  
  cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam) # shape (S, 4, 4) typically
  # For convenience, we store only (3,4) portion
@@ -132,13 +133,12 @@ def viser_wrapper(
  
  # Create the main point cloud handle
  # Compute the threshold value as the given percentile
- init_threshold_val = np.percentile(conf, init_conf_threshold)
- init_conf_mask = conf > init_threshold_val
+ init_threshold_val = np.percentile(conf_flat, init_conf_threshold)
+ init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1)
  point_cloud = server.scene.add_point_cloud(
  name="viser_pcd",
  points=points_centered[init_conf_mask],
  colors=colors_flat[init_conf_mask],
- # point_size=0.0001,
  point_size=0.001,
  point_shape="circle",
  )
@@ -213,8 +213,11 @@ def viser_wrapper(
  """Update the point cloud based on current GUI selections."""
  # Here we compute the threshold value based on the current percentage
  current_percentage = gui_points_conf.value
- threshold_val = np.percentile(conf, current_percentage)
- conf_mask = conf > threshold_val
+ threshold_val = np.percentile(conf_flat, current_percentage)
+ 
+ print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")
+ 
+ conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)
  
  if gui_frame_selector.value == "All":
  frame_mask = np.ones_like(conf_mask, dtype=bool)
@@ -264,30 +267,6 @@ def viser_wrapper(
  
  # Helper functions for sky segmentation
  
- def download_file_from_url(url, filename):
- """Downloads a file from a Hugging Face model repo, handling redirects."""
- try:
- # Get the redirect URL
- response = requests.get(url, allow_redirects=False)
- response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx)
- 
- if response.status_code == 302: # Expecting a redirect
- redirect_url = response.headers["Location"]
- response = requests.get(redirect_url, stream=True)
- response.raise_for_status()
- else:
- print(f"Unexpected status code: {response.status_code}")
- return
- 
- with open(filename, "wb") as f:
- for chunk in response.iter_content(chunk_size=8192):
- f.write(chunk)
- print(f"Downloaded {filename} successfully.")
- 
- except requests.exceptions.RequestException as e:
- print(f"Error downloading file: {e}")
- 
- 
  
  
  def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray:
@@ -335,7 +314,7 @@ def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray:
  # Convert list to numpy array with shape S×H×W
  sky_mask_array = np.array(sky_mask_list)
  # Apply sky mask to confidence scores
- sky_mask_binary = (sky_mask_array > 0.01).astype(np.float32)
+ sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
  conf = conf * sky_mask_binary
  
  print("Sky segmentation applied successfully")
@@ -343,73 +322,6 @@
  
  
  
- def segment_sky(image_path, onnx_session, mask_filename=None):
- """
- Segments sky from an image using an ONNX model.
- 
- Args:
- image_path: Path to input image
- onnx_session: ONNX runtime session with loaded model
- mask_filename: Path to save the output mask
- 
- Returns:
- np.ndarray: Binary mask where 255 indicates non-sky regions
- """
- assert mask_filename is not None
- image = cv2.imread(image_path)
- 
- result_map = run_skyseg(onnx_session, [320, 320], image)
- # resize the result_map to the original image size
- result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
- 
- output_mask = np.zeros_like(result_map_original)
- output_mask[result_map_original < 1] = 1
- output_mask = output_mask.astype(np.uint8) * 255
- os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
- cv2.imwrite(mask_filename, output_mask)
- return output_mask
- 
- 
- def run_skyseg(onnx_session, input_size, image):
- """
- Runs sky segmentation inference using ONNX model.
- 
- Args:
- onnx_session: ONNX runtime session
- input_size: Target size for model input (width, height)
- image: Input image in BGR format
- 
- Returns:
- np.ndarray: Segmentation mask
- """
- # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
- temp_image = copy.deepcopy(image)
- resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
- x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
- x = np.array(x, dtype=np.float32)
- mean = [0.485, 0.456, 0.406]
- std = [0.229, 0.224, 0.225]
- x = (x / 255 - mean) / std
- x = x.transpose(2, 0, 1)
- x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
- 
- # Inference
- input_name = onnx_session.get_inputs()[0].name
- output_name = onnx_session.get_outputs()[0].name
- onnx_result = onnx_session.run([output_name], {input_name: x})
- 
- # Post process
- onnx_result = np.array(onnx_result).squeeze()
- min_value = np.min(onnx_result)
- max_value = np.max(onnx_result)
- onnx_result = (onnx_result - min_value) / (max_value - min_value)
- onnx_result *= 255
- onnx_result = onnx_result.astype("uint8")
- 
- return onnx_result
- 
- 
- 
  
  
  
@@ -450,6 +362,8 @@ def main():
  print(f"Using device: {device}")
  
  print("Initializing and loading VGGT model...")
+ # model = VGGT.from_pretrained("facebook/VGGT-1B")
+ 
  model = VGGT()
  _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
  model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
@@ -503,4 +417,4 @@ def main():
  
  
  if __name__ == "__main__":
- main()
+ main()
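In demo_viser.py the confidence slider is now treated as a percentile: `np.percentile(conf_flat, current_percentage)` turns the GUI value (0-100) into an absolute threshold, and the extra `conf_flat > 1e-5` term drops points whose confidence was zeroed out (for example by the sky mask). A small toy example of that masking logic, assuming only NumPy:

import numpy as np

# Toy per-point confidences; zeros stand in for points removed by sky masking.
conf_flat = np.array([0.0, 0.0, 0.2, 0.5, 0.9, 1.4, 2.0])

current_percentage = 50  # slider value, read as a percentile
threshold_val = np.percentile(conf_flat, current_percentage)  # -> 0.5

# Keep points at or above the percentile threshold, but never the zeroed ones.
conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)
print(conf_mask)  # [False False False  True  True  True  True]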
gradio_util.py → visual_util.py RENAMED
@@ -131,7 +131,7 @@ def predictions_to_glb(
  sky_mask_array = np.array(sky_mask_list)
  
  # Apply sky mask to confidence scores
- sky_mask_binary = (sky_mask_array > 0.01).astype(np.float32)
+ sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
  pred_world_points_conf = pred_world_points_conf * sky_mask_binary
  
  if selected_frame_idx is not None:
@@ -155,7 +155,7 @@
  else:
  conf_threshold = np.percentile(conf, conf_thres)
  
- conf_mask = conf >= conf_threshold
+ conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
  
  if mask_black_bg:
  black_bg_mask = colors_rgb.sum(axis=1) >= 16
@@ -370,6 +370,7 @@ def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
  def segment_sky(image_path, onnx_session, mask_filename=None):
  """
  Segments sky from an image using an ONNX model.
+ Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing
  
  Args:
  image_path: Path to input image
@@ -387,9 +388,11 @@ def segment_sky(image_path, onnx_session, mask_filename=None):
  # resize the result_map to the original image size
  result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
  
+ # Fix: Invert the mask so that 255 = non-sky, 0 = sky
+ # The model outputs low values for sky, high values for non-sky
  output_mask = np.zeros_like(result_map_original)
- output_mask[result_map_original < 1] = 1
- output_mask = output_mask.astype(np.uint8) * 255
+ output_mask[result_map_original < 32] = 255 # Use threshold of 32
+ 
  os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
  cv2.imwrite(mask_filename, output_mask)
  return output_mask
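The visual_util.py changes work as a pair: the sky mask is binarized at 0.1 instead of 0.01 before being multiplied into the confidences, and the later `conf_mask = (conf >= conf_threshold) & (conf > 1e-5)` then ignores the points that multiplication zeroed out, even when the percentile threshold itself is very small. A short sketch of that interaction with made-up arrays (not real model output):

import numpy as np

conf = np.array([0.8, 1.2, 0.3, 2.0])              # per-point confidence
sky_mask_array = np.array([0.05, 0.9, 0.02, 1.0])  # soft sky mask; values <= 0.1 get zeroed as sky

sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
conf = conf * sky_mask_binary                       # -> [0.0, 1.2, 0.0, 2.0]

conf_threshold = np.percentile(conf, 50.0)          # 0.6 for this toy array
conf_mask = (conf >= conf_threshold) & (conf > 1e-5)
print(conf_mask)                                    # [False  True False  True]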