Spaces: Running on Zero
Commit · e8227e4
Parent(s): e404aa3
update readme

Files changed:
- app.py (+13 -12)
- demo_gradio.py (+23 -19)
- demo_viser.py (+15 -101)
- gradio_util.py → visual_util.py (+7 -4)
app.py
CHANGED
@@ -15,25 +15,27 @@ from datetime import datetime
 import glob
 import gc
 import time
-import spaces
+# import spaces


 sys.path.append("vggt/")

-from gradio_util import predictions_to_glb
+from visual_util import predictions_to_glb
 from vggt.models.vggt import VGGT
 from vggt.utils.load_fn import load_and_preprocess_images
 from vggt.utils.pose_enc import pose_encoding_to_extri_intri
 from vggt.utils.geometry import unproject_depth_map_to_point_map

+# device = "cuda" if torch.cuda.is_available() else "cpu"

 print("Initializing and loading VGGT model...")
 # model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model

-# device = "cuda" if torch.cuda.is_available() else "cpu"
 model = VGGT()
 _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
 model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
+
+
 model.eval()
 # model = model.to(device)

@@ -41,7 +43,7 @@ model.eval()
 # -------------------------------------------------------------------------
 # 1) Core model inference
 # -------------------------------------------------------------------------
-@spaces.GPU(duration=120)
+# @spaces.GPU(duration=120)
 def run_model(target_dir, model) -> dict:
     """
     Run the VGGT model on images in the 'target_dir/images' folder and return predictions.
@@ -181,7 +183,7 @@ def update_gallery_on_upload(input_video, input_images):
 # -------------------------------------------------------------------------
 # 4) Reconstruction: uses the target_dir plus any viz parameters
 # -------------------------------------------------------------------------
-@spaces.GPU(duration=120)
+# @spaces.GPU(duration=120)
 def gradio_demo(
     target_dir,
     conf_thres=3.0,
@@ -313,7 +315,7 @@ def update_visualization(
 # Example images
 # -------------------------------------------------------------------------

-canyon_video = "examples/videos/Studlagil_Canyon_East_Iceland.mp4"
+# canyon_video = "examples/videos/Studlagil_Canyon_East_Iceland.mp4"
 great_wall_video = "examples/videos/great_wall.mp4"
 colosseum_video = "examples/videos/Colosseum.mp4"
 room_video = "examples/videos/room.mp4"
@@ -392,9 +394,9 @@ with gr.Blocks(

 <h3>Getting Started:</h3>
 <ol>
-<li><strong>Upload Your Data:</strong> Use the
+<li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
 <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
-<li><strong>Reconstruct:</strong> Click the
+<li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
 <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.</li>
 <li>
 <strong>Adjust Visualization (Optional):</strong>
@@ -406,17 +408,16 @@ with gr.Blocks(
 <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
 <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
 <li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
-<li><em>Select a Prediction Mode:</em> Choose between
+<li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
 </ul>
 </details>
 </li>
 </ol>
-<p><strong>Please note:</strong> Our method usually only needs less than 1 second to reconstruct a scene, but the visualization of 3D points may take tens of seconds
+<p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">Our method usually only needs less than 1 second to reconstruct a scene, but the visualization of 3D points may take tens of seconds</span>, especially when the number of images is large. Please be patient or, for faster visualization, use a local machine to run our demo from our <a href="https://github.com/facebookresearch/vggt">GitHub repository</a>.</p>
 </div>
 """
 )

-
 target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")

 with gr.Row():
@@ -472,7 +473,7 @@
 [pyramid_video, "30", None, 35.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
 [single_cartoon_video, "1", None, 15.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
 [single_oil_painting_video, "1", None, 20.0, False, True, True, True, "Depthmap and Camera Branch", "True"],
-[canyon_video, "14", None, 40.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
+# [canyon_video, "14", None, 40.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
 [room_video, "8", None, 5.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
 [kitchen_video, "25", None, 50.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
 [fern_video, "20", None, 45.0, False, False, True, False, "Depthmap and Camera Branch", "True"],
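The main change in app.py is that the ZeroGPU hooks (the spaces import and the @spaces.GPU(duration=120) decorators) are commented out and the model is left on CPU (# model = model.to(device) stays disabled). As a rough illustration of the two startup styles this toggles between, here is a minimal sketch; it is not code from the repository, run_inference is a hypothetical placeholder, and the guarded import assumes the spaces package available on Hugging Face ZeroGPU Spaces.

import torch

# Plain-CUDA startup, as kept in demo_gradio.py: pick a device once and move the model there.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ZeroGPU startup, the path app.py disables in this commit: the GPU is attached only
# while the decorated function runs, so nothing is moved to CUDA at import time.
try:
    import spaces  # only available on Hugging Face ZeroGPU Spaces

    @spaces.GPU(duration=120)  # request a GPU for up to 120 seconds per call
    def run_inference(model, images):  # hypothetical stand-in for run_model()/gradio_demo()
        return model(images)
except ImportError:
    pass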
demo_gradio.py
CHANGED
@@ -18,20 +18,22 @@ import time

 sys.path.append("vggt/")

-from gradio_util import predictions_to_glb
+from visual_util import predictions_to_glb
 from vggt.models.vggt import VGGT
 from vggt.utils.load_fn import load_and_preprocess_images
 from vggt.utils.pose_enc import pose_encoding_to_extri_intri
 from vggt.utils.geometry import unproject_depth_map_to_point_map

+device = "cuda" if torch.cuda.is_available() else "cpu"

 print("Initializing and loading VGGT model...")
 # model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model

-device = "cuda" if torch.cuda.is_available() else "cpu"
 model = VGGT()
 _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
 model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
+
+
 model.eval()
 model = model.to(device)

@@ -375,35 +377,37 @@ with gr.Blocks(
 is_example = gr.Textbox(label="is_example", visible=False, value="None")
 num_images = gr.Textbox(label="num_images", visible=False, value="None")

-gr.
-
-
-
-
+gr.HTML(
+"""
+<h1>VGGT: Visual Geometry Grounded Transformer</h1>
+<p>
+<a href="https://github.com/facebookresearch/vggt">GitHub Repository</a> |
+<a href="#">Project Page</a>
+</p>

 <div style="font-size: 16px; line-height: 1.5;">
-<p>Upload a video or a set of images to create a 3D reconstruction of a scene or object.
+<p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. VGGT takes these images and generates a 3D point cloud, along with estimated camera poses.</p>

 <h3>Getting Started:</h3>
 <ol>
-<li><strong>Upload Your Data:</strong> Use the
-<li><strong>Preview:</strong>
-<li><strong>Reconstruct:</strong> Click the
-<li><strong>Visualize:</strong>
-
+<li><strong>Upload Your Data:</strong> Use the “Upload Video” or “Upload Images” buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
+<li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
+<li><strong>Reconstruct:</strong> Click the “Reconstruct” button to start the 3D reconstruction process.</li>
+<li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.</li>
+<li>
 <strong>Adjust Visualization (Optional):</strong>
-After reconstruction, you can fine-tune the visualization using the options below
+After reconstruction, you can fine-tune the visualization using the options below
 <details style="display:inline;">
-
-
+<summary style="display:inline;">(<strong>click to expand</strong>):</summary>
+<ul>
 <li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
 <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
 <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
 <li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
-<li><em>Select a Prediction Mode:</em> Choose between
-
+<li><em>Select a Prediction Mode:</em> Choose between “Depthmap and Camera Branch” or “Pointmap Branch.”</li>
+</ul>
 </details>
-
+</li>
 </ol>
 <p><strong>Please note:</strong> Our method usually only needs less than 1 second to reconstruct a scene, but the visualization of 3D points may take tens of seconds, especially when the number of images is large. Please be patient or, for faster visualization, use a local machine to run our demo from our <a href="https://github.com/facebookresearch/vggt">GitHub repository</a>.</p>
 </div>
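The demos keep the same checkpoint-loading pattern: construct VGGT() and pull the weights through torch.hub.load_state_dict_from_url, with VGGT.from_pretrained("facebook/VGGT-1B") left as a commented-out alternative. A minimal sketch of that pattern follows; the load_vggt wrapper and the explicit map_location="cpu" are additions for illustration, not part of the diff.

import torch
from vggt.models.vggt import VGGT

_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"

def load_vggt(device: str = "cpu") -> VGGT:
    # Build the module, then load the checkpoint through torch.hub's cached downloader.
    model = VGGT()
    state_dict = torch.hub.load_state_dict_from_url(_URL, map_location="cpu")
    model.load_state_dict(state_dict)
    # model = VGGT.from_pretrained("facebook/VGGT-1B")  # alternative kept commented out in the demos
    return model.eval().to(device)

model = load_vggt("cuda" if torch.cuda.is_available() else "cpu")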
demo_viser.py
CHANGED
@@ -10,7 +10,6 @@ import time
 import threading
 import argparse
 from typing import List, Optional
-import copy

 import numpy as np
 import torch
@@ -18,12 +17,14 @@ from tqdm.auto import tqdm
 import viser
 import viser.transforms as viser_tf
 import cv2
-
+
+
 try:
     import onnxruntime
 except ImportError:
     print("onnxruntime not found. Sky segmentation may not work.")

+from visual_util import segment_sky, download_file_from_url
 from vggt.models.vggt import VGGT
 from vggt.utils.load_fn import load_and_preprocess_images
 from vggt.utils.geometry import closed_form_inverse_se3, unproject_depth_map_to_point_map
@@ -95,7 +96,7 @@ def viser_wrapper(
     # Flatten
     points = world_points.reshape(-1, 3)
     colors_flat = (colors.reshape(-1, 3) * 255).astype(np.uint8)
-
+    conf_flat = conf.reshape(-1)

     cam_to_world_mat = closed_form_inverse_se3(extrinsics_cam)  # shape (S, 4, 4) typically
     # For convenience, we store only (3,4) portion
@@ -132,13 +133,12 @@ def viser_wrapper(

     # Create the main point cloud handle
     # Compute the threshold value as the given percentile
-    init_threshold_val = np.percentile(
-    init_conf_mask =
+    init_threshold_val = np.percentile(conf_flat, init_conf_threshold)
+    init_conf_mask = (conf_flat >= init_threshold_val) & (conf_flat > 0.1)
     point_cloud = server.scene.add_point_cloud(
         name="viser_pcd",
         points=points_centered[init_conf_mask],
         colors=colors_flat[init_conf_mask],
-        # point_size=0.0001,
         point_size=0.001,
         point_shape="circle",
     )
@@ -213,8 +213,11 @@ def viser_wrapper(
         """Update the point cloud based on current GUI selections."""
         # Here we compute the threshold value based on the current percentage
         current_percentage = gui_points_conf.value
-        threshold_val = np.percentile(
-
+        threshold_val = np.percentile(conf_flat, current_percentage)
+
+        print(f"Threshold absolute value: {threshold_val}, percentage: {current_percentage}%")
+
+        conf_mask = (conf_flat >= threshold_val) & (conf_flat > 1e-5)

         if gui_frame_selector.value == "All":
             frame_mask = np.ones_like(conf_mask, dtype=bool)
@@ -264,30 +267,6 @@ def viser_wrapper(

 # Helper functions for sky segmentation

-def download_file_from_url(url, filename):
-    """Downloads a file from a Hugging Face model repo, handling redirects."""
-    try:
-        # Get the redirect URL
-        response = requests.get(url, allow_redirects=False)
-        response.raise_for_status()  # Raise HTTPError for bad requests (4xx or 5xx)
-
-        if response.status_code == 302:  # Expecting a redirect
-            redirect_url = response.headers["Location"]
-            response = requests.get(redirect_url, stream=True)
-            response.raise_for_status()
-        else:
-            print(f"Unexpected status code: {response.status_code}")
-            return
-
-        with open(filename, "wb") as f:
-            for chunk in response.iter_content(chunk_size=8192):
-                f.write(chunk)
-        print(f"Downloaded {filename} successfully.")
-
-    except requests.exceptions.RequestException as e:
-        print(f"Error downloading file: {e}")
-
-

 def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray:
     """
@@ -335,7 +314,7 @@ def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray:
     # Convert list to numpy array with shape S×H×W
     sky_mask_array = np.array(sky_mask_list)
     # Apply sky mask to confidence scores
-    sky_mask_binary = (sky_mask_array > 0.
+    sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
     conf = conf * sky_mask_binary

     print("Sky segmentation applied successfully")
@@ -343,73 +322,6 @@ def apply_sky_segmentation(conf: np.ndarray, image_folder: str) -> np.ndarray:



-def segment_sky(image_path, onnx_session, mask_filename=None):
-    """
-    Segments sky from an image using an ONNX model.
-
-    Args:
-        image_path: Path to input image
-        onnx_session: ONNX runtime session with loaded model
-        mask_filename: Path to save the output mask
-
-    Returns:
-        np.ndarray: Binary mask where 255 indicates non-sky regions
-    """
-    assert mask_filename is not None
-    image = cv2.imread(image_path)
-
-    result_map = run_skyseg(onnx_session, [320, 320], image)
-    # resize the result_map to the original image size
-    result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
-
-    output_mask = np.zeros_like(result_map_original)
-    output_mask[result_map_original < 1] = 1
-    output_mask = output_mask.astype(np.uint8) * 255
-    os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
-    cv2.imwrite(mask_filename, output_mask)
-    return output_mask
-
-
-def run_skyseg(onnx_session, input_size, image):
-    """
-    Runs sky segmentation inference using ONNX model.
-
-    Args:
-        onnx_session: ONNX runtime session
-        input_size: Target size for model input (width, height)
-        image: Input image in BGR format
-
-    Returns:
-        np.ndarray: Segmentation mask
-    """
-    # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
-    temp_image = copy.deepcopy(image)
-    resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
-    x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
-    x = np.array(x, dtype=np.float32)
-    mean = [0.485, 0.456, 0.406]
-    std = [0.229, 0.224, 0.225]
-    x = (x / 255 - mean) / std
-    x = x.transpose(2, 0, 1)
-    x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
-
-    # Inference
-    input_name = onnx_session.get_inputs()[0].name
-    output_name = onnx_session.get_outputs()[0].name
-    onnx_result = onnx_session.run([output_name], {input_name: x})
-
-    # Post process
-    onnx_result = np.array(onnx_result).squeeze()
-    min_value = np.min(onnx_result)
-    max_value = np.max(onnx_result)
-    onnx_result = (onnx_result - min_value) / (max_value - min_value)
-    onnx_result *= 255
-    onnx_result = onnx_result.astype("uint8")
-
-    return onnx_result
-
-
-



@@ -450,6 +362,8 @@ def main():
     print(f"Using device: {device}")

     print("Initializing and loading VGGT model...")
+    # model = VGGT.from_pretrained("facebook/VGGT-1B")
+
     model = VGGT()
     _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
     model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
@@ -503,4 +417,4 @@ def main():


 if __name__ == "__main__":
-    main()
+    main()
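The viser demo now flattens the confidence map once (conf_flat = conf.reshape(-1)) and derives its point mask from a percentile of that array plus a small absolute floor, so points zeroed out by the sky mask cannot reappear at low slider values. A self-contained sketch of that masking logic follows; the confidence_mask helper and the random test array are illustrative only.

import numpy as np

def confidence_mask(conf_flat: np.ndarray, percentage: float, floor: float = 1e-5) -> np.ndarray:
    # Interpret the slider value as a percentile of the per-point confidences,
    # then keep points at or above that percentile and strictly above the floor.
    threshold_val = np.percentile(conf_flat, percentage)
    return (conf_flat >= threshold_val) & (conf_flat > floor)

# Example: keep roughly the top half of a synthetic confidence array.
conf_flat = np.random.rand(10_000).astype(np.float32)
mask = confidence_mask(conf_flat, 50.0)
print(f"kept {mask.sum()} of {conf_flat.size} points")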
gradio_util.py → visual_util.py
RENAMED
@@ -131,7 +131,7 @@ def predictions_to_glb(
     sky_mask_array = np.array(sky_mask_list)

     # Apply sky mask to confidence scores
-    sky_mask_binary = (sky_mask_array > 0.
+    sky_mask_binary = (sky_mask_array > 0.1).astype(np.float32)
     pred_world_points_conf = pred_world_points_conf * sky_mask_binary

     if selected_frame_idx is not None:
@@ -155,7 +155,7 @@ def predictions_to_glb(
     else:
         conf_threshold = np.percentile(conf, conf_thres)

-    conf_mask = conf >= conf_threshold
+    conf_mask = (conf >= conf_threshold) & (conf > 1e-5)

     if mask_black_bg:
         black_bg_mask = colors_rgb.sum(axis=1) >= 16
@@ -370,6 +370,7 @@ def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
 def segment_sky(image_path, onnx_session, mask_filename=None):
     """
     Segments sky from an image using an ONNX model.
+    Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing

     Args:
         image_path: Path to input image
@@ -387,9 +388,11 @@ def segment_sky(image_path, onnx_session, mask_filename=None):
     # resize the result_map to the original image size
     result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))

+    # Fix: Invert the mask so that 255 = non-sky, 0 = sky
+    # The model outputs low values for sky, high values for non-sky
     output_mask = np.zeros_like(result_map_original)
-    output_mask[result_map_original < 1] = 1
-    output_mask = output_mask.astype(np.uint8) * 255
+    output_mask[result_map_original < 32] = 255  # Use threshold of 32
+
     os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
     cv2.imwrite(mask_filename, output_mask)
     return output_mask
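The segment_sky change replaces the old two-step mask construction with a single thresholding step: pixels whose raw ONNX score is below 32 are written as 255 in the mask, and predictions_to_glb / apply_sky_segmentation then binarise that mask (> 0.1) and multiply it into the confidence map, so masked-out pixels get confidence 0 and are later dropped by the (conf > 1e-5) check. A toy end-to-end sketch of that interaction follows; the helper names and the 2×2 score map are made up for illustration.

import numpy as np

def sky_scores_to_mask(result_map: np.ndarray) -> np.ndarray:
    # Mirror the new segment_sky thresholding: scores below 32 become 255 (kept),
    # everything else stays 0 (treated as sky and filtered out downstream).
    output_mask = np.zeros_like(result_map, dtype=np.uint8)
    output_mask[result_map < 32] = 255
    return output_mask

def weight_confidence_by_sky(conf: np.ndarray, sky_mask: np.ndarray) -> np.ndarray:
    # Binarise the mask and zero out the confidence of masked pixels.
    sky_mask_binary = (sky_mask > 0.1).astype(np.float32)
    return conf * sky_mask_binary

scores = np.array([[5.0, 200.0], [10.0, 250.0]], dtype=np.float32)  # fake sky scores
conf = np.ones_like(scores)
weighted = weight_confidence_by_sky(conf, sky_scores_to_mask(scores))
# weighted == [[1, 0], [1, 0]]: pixels with scores >= 32 now fail the conf > 1e-5 filter.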