# (web-viewer residue, not part of the program) Last commit not found
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# gradio demo
# --------------------------------------------------------
import argparse
import math
import gradio
import os
import shutil
import torch
import numpy as np
import tempfile
import functools
import copy
from tqdm import tqdm
import cv2
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.image_pairs import make_pairs
from dust3r.utils.image_pose import load_images, rgb, enlarge_seg_masks
from dust3r.utils.device import to_numpy
from dust3r.cloud_opt_flow import global_aligner, GlobalAlignerMode
import matplotlib.pyplot as pl
from transformers import pipeline
from dust3r.utils.viz_demo import convert_scene_output_to_glb
import depth_pro
import spaces
# ---- demo-wide settings (read as module globals by the functions below) ----
batch_size = 1                # forwarded to inference(..., batch_size=...)
image_size = 512              # forwarded to load_images(..., size=...)
silent = True                 # helpers are called with verbose=not silent
gradio_delete_cache = 7200    # used for both entries of gradio.Blocks(delete_cache=(..., ...))

# scratch directory created once at import time
tmpdirname = tempfile.mkdtemp(suffix='_align3r_gradio_demo')

# matplotlib in interactive mode
pl.ion()

# for gpu >= Ampere and pytorch >= 1.12
torch.backends.cuda.matmul.allow_tf32 = True
class FileState:
    """Hold the path of a generated output file and delete it on finalization.

    Args:
        outfile_name: path of the file owned by this state, or None when
            there is nothing to clean up.
    """

    def __init__(self, outfile_name=None):
        self.outfile_name = outfile_name

    def __del__(self):
        # getattr: the finalizer can run on an instance whose __init__ never
        # completed, in which case the attribute does not exist.
        outfile_name = getattr(self, "outfile_name", None)
        if outfile_name is not None:
            try:
                # Only regular files are removed; directories are left alone.
                if os.path.isfile(outfile_name):
                    os.remove(outfile_name)
            except OSError:
                # Best-effort cleanup: the file may vanish (or become
                # inaccessible) between the check and the removal, and a
                # finalizer has no caller that could handle the error.
                pass
        self.outfile_name = None
def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05, show_cam=True, save_name=None,
                            thr_for_init_conf=True):
    """
    Export a reconstructed scene as a 3D model (glb file).

    Returns None when `scene` is None; otherwise returns whatever
    `convert_scene_output_to_glb` returns for the exported scene.
    `min_conf_thr` and `thr_for_init_conf` are written onto the scene before
    its masks are queried; the remaining options are forwarded to the exporter.
    """
    if scene is None:
        return None

    # optional post-processing, depth clean-up first and sky masking second
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # optimized values read back from the scene
    images = scene.imgs
    focal_lengths = scene.get_focals().cpu()
    poses_cam2world = scene.get_im_poses().cpu()

    # 3D pointcloud from depthmap, poses and intrinsics
    points = to_numpy(scene.get_pts3d(raw_pts=True))

    # thresholds are set on the scene right before the masks are computed
    scene.min_conf_thr = min_conf_thr
    scene.thr_for_init_conf = thr_for_init_conf
    masks = to_numpy(scene.get_masks())

    # one viridis color per camera, RGB channels scaled by 255
    colormap = pl.get_cmap('viridis')
    n_views = len(images)
    camera_colors = []
    for view_idx in range(n_views):
        red, green, blue = colormap(view_idx / n_views)[:3]
        camera_colors.append((255 * red, 255 * green, 255 * blue))

    return convert_scene_output_to_glb(
        outdir, images, points, masks, focal_lengths, poses_cam2world,
        as_pointcloud=as_pointcloud, transparent_cams=transparent_cams, cam_size=cam_size,
        show_cam=show_cam, silent=silent, save_name=save_name, cam_color=camera_colors)
def _canonical_depth_prior_name(depth_prior_name):
    """Map a depth-prior identifier or UI label to 'depthpro' / 'depthanything'.

    Case, spaces and punctuation are ignored, so 'depthpro', 'Depth Pro',
    'depthanything' and 'Depth Anything V2' are all accepted.

    Raises:
        ValueError: if the name matches neither model (including None).
    """
    key = ''.join(ch for ch in str(depth_prior_name).lower() if ch.isalnum())
    if key.startswith('depthpro'):
        return 'depthpro'
    if key.startswith('depthanything'):
        return 'depthanything'
    raise ValueError(f"Unknown monocular depth estimation model: {depth_prior_name!r}")


def generate_monocular_depth_maps(img_list, depth_prior_name):
    """Run a monocular depth estimator on every image path in `img_list`.

    Args:
        img_list: iterable of image file paths.
        depth_prior_name: which estimator to run; 'depthpro' or
            'depthanything' (the UI labels 'Depth Pro' and
            'Depth Anything V2' are accepted as well).

    Returns:
        (depth_list, focallength_px_list), one entry per input image and in
        input order. For Depth Pro both entries are CPU tensors taken from the
        model prediction; for Depth Anything the depth is a numpy array and
        the focal length is the constant 200.

    Raises:
        ValueError: if `depth_prior_name` is not a known estimator.
    """
    # Validate first so a bad name fails loudly instead of yielding empty lists.
    prior = _canonical_depth_prior_name(depth_prior_name)

    depth_list = []
    focallength_px_list = []
    if not img_list:
        # Nothing to process: skip loading any model.
        return depth_list, focallength_px_list

    if prior == 'depthpro':
        depth_model, transform = depth_pro.create_model_and_transforms(device='cuda')
        depth_model.eval()
        for image_path in tqdm(img_list):
            image, _, f_px = depth_pro.load_rgb(image_path)
            image = transform(image)
            # Run inference.
            prediction = depth_model.infer(image, f_px=f_px)
            depth_list.append(prediction["depth"].cpu())  # Depth in [m].
            focallength_px_list.append(prediction["focallength_px"].cpu())
    else:
        pipe = pipeline(task="depth-estimation", model="depth-anything/Depth-Anything-V2-Large-hf", device='cuda')
        for image_path in tqdm(img_list):
            # The pipeline accepts a local image path directly, so no separate
            # image-loading import is needed here.
            depth = pipe(image_path)["predicted_depth"].numpy()
            depth_list.append(depth)
            # Only the depth is read from the pipeline output; a fixed
            # placeholder focal length is recorded for this estimator.
            focallength_px_list.append(200)
    return depth_list, focallength_px_list
def local_get_reconstructed_scene(filelist, min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams,
                                  cam_size, depth_prior_name, show_cam=True, **kw):
    """Reconstruct a scene from the first two uploaded images and export it.

    Pipeline: monocular depth maps -> load_images -> pairwise inference ->
    PairViewer global alignment -> glb export into './output'.
    Uses the module-level `model`, `device`, `batch_size`, `image_size` and
    `silent`.

    Args:
        filelist: uploaded image paths; at least two are required and only
            the first two are paired.
        min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams,
        cam_size, show_cam: forwarded to `get_3D_model_from_scene`.
        depth_prior_name: 'depthpro' / 'depthanything', or the UI labels
            'Depth Pro' / 'Depth Anything V2'.
        **kw: accepted for call-site compatibility and ignored.

    Returns:
        The value returned by `get_3D_model_from_scene` for the exported scene.

    Raises:
        ValueError: fewer than two images, or an unknown depth prior name.
    """
    if not filelist or len(filelist) < 2:
        raise ValueError("Please upload two images before hitting the run button.")

    # Canonicalise the prior name so that a UI label and an internal identifier
    # select the same estimator.
    # NOTE(review): load_images is assumed to expect the same identifiers that
    # generate_monocular_depth_maps compares against - confirm.
    prior_key = ''.join(ch for ch in str(depth_prior_name).lower() if ch.isalnum())
    if prior_key.startswith('depthpro'):
        depth_prior_name = 'depthpro'
    elif prior_key.startswith('depthanything'):
        depth_prior_name = 'depthanything'
    else:
        raise ValueError(f"Unknown monocular depth estimation model: {depth_prior_name!r}")

    depth_list, focallength_px_list = generate_monocular_depth_maps(filelist, depth_prior_name)
    imgs = load_images(filelist, depth_list, focallength_px_list, size=image_size, verbose=not silent,
                       traj_format='custom', depth_prior_name=depth_prior_name)

    # A single pair built from the first two images.
    pairs = [(imgs[0], imgs[1])]
    output = inference(pairs, model, device, batch_size=batch_size, verbose=not silent)

    mode = GlobalAlignerMode.PairViewer
    scene = global_aligner(output, device=device, mode=mode, verbose=not silent)

    save_folder = './output'
    os.makedirs(save_folder, exist_ok=True)  # make sure the export target exists
    # show_cam is a parameter here, so the export receives a plain value rather
    # than whatever module-level object happens to be named `show_cam`.
    return get_3D_model_from_scene(save_folder, silent, scene, min_conf_thr, as_pointcloud, mask_sky, clean_depth,
                                   transparent_cams, cam_size, show_cam)
def run_example(snapshot, matching_conf_thr, min_conf_thr, cam_size, as_pointcloud, shared_intrinsics, filelist, **kw):
    """Adapter with the argument order of a gradio.Examples row.

    `snapshot`, `matching_conf_thr` and `shared_intrinsics` are accepted for
    that row layout but are not used by the reconstruction. Options the row
    does not carry (`mask_sky`, `clean_depth`, `transparent_cams`,
    `depth_prior_name`) may be supplied through `**kw`; otherwise they fall
    back to the UI checkbox defaults and to the 'depthanything' prior.

    Returns:
        The result of `local_get_reconstructed_scene`.
    """
    mask_sky = kw.pop('mask_sky', False)
    clean_depth = kw.pop('clean_depth', True)
    transparent_cams = kw.pop('transparent_cams', False)
    depth_prior_name = kw.pop('depth_prior_name', 'depthanything')
    # Every positional slot of local_get_reconstructed_scene is filled
    # explicitly so that each value lands in the parameter it is meant for.
    return local_get_reconstructed_scene(filelist, min_conf_thr, as_pointcloud, mask_sky, clean_depth,
                                         transparent_cams, cam_size, depth_prior_name, **kw)
css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
title = "Align3R Demo"

# Align3R checkpoint that goes with each monocular depth prior. The keys are
# the values emitted by the depth-prior dropdown below.
_weights_by_prior = {
    'depthpro': "cyun9286/Align3R_DepthPro_ViTLarge_BaseDecoder_512_dpt",
    'depthanything': "cyun9286/Align3R_DepthAnythingV2_ViTLarge_BaseDecoder_512_dpt",
}
_default_prior = 'depthanything'
_loaded_prior = _default_prior

# `device` and `model` are module globals read by the reconstruction function.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
weights_path = _weights_by_prior[_default_prior]
model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)


def _reconstruct_with_selected_prior(filelist, min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams,
                                     cam_size, depth_prior_name, show_cam):
    """Click handler: load the checkpoint matching the selected prior, then reconstruct.

    A UI component cannot be compared with a string while the layout is being
    built (the comparison is evaluated once, against the component object), so
    the checkpoint choice is made here, per request, from the submitted value.
    Only one checkpoint is kept loaded; it is swapped when the selection changes.
    """
    global model, weights_path, _loaded_prior
    if depth_prior_name not in _weights_by_prior:
        raise gradio.Error("Please select a monocular depth estimation model.")
    if depth_prior_name != _loaded_prior:
        weights_path = _weights_by_prior[depth_prior_name]
        model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
        _loaded_prior = depth_prior_name
    # show_cam is passed by keyword so it reaches the callee without changing
    # its positional signature.
    return local_get_reconstructed_scene(filelist, min_conf_thr, as_pointcloud, mask_sky, clean_depth,
                                         transparent_cams, cam_size, depth_prior_name, show_cam=show_cam)


# NOTE(review): the container nesting below is reconstructed from statement
# order (the reviewed text carried no indentation) - confirm the layout.
with gradio.Blocks(css=css, title=title, delete_cache=(gradio_delete_cache, gradio_delete_cache)) as demo:
    filestate = gradio.State(None)
    gradio.HTML('<h2 style="text-align: center;">3D Reconstruction with Align3R</h2>')
    gradio.HTML('<p>Upload two images (wait for them to be fully uploaded before hitting the run button). '
                'If you want to try larger image collections, you can find the more complete version of this demo that you can run locally '
                'and more details about the method at <a href="https://github.com/jiah-cloud/Align3R">github.com/jiah-cloud/Align3R</a>. '
                'The checkpoint used in this demo is available at <a href="https://huggingface.co/cyun9286/Align3R_DepthAnythingV2_ViTLarge_BaseDecoder_512_dpt">Align3R (Depth Anything V2)</a> and <a href="https://huggingface.co/cyun9286/Align3R_DepthPro_ViTLarge_BaseDecoder_512_dpt">Align3R (Depth Pro)</a>.</p>')
    with gradio.Column():
        inputfiles = gradio.File(file_count="multiple")
        snapshot = gradio.Image(None, visible=False)
        with gradio.Row():
            # adjust the camera size in the output pointcloud
            cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
            # (label, value) choices: the UI shows the label, the handler
            # receives the internal identifier.
            depth_prior_name = gradio.Dropdown(
                [("Depth Pro", 'depthpro'), ("Depth Anything V2", 'depthanything')], value=_default_prior,
                label="monocular depth estimation model", info="Select the monocular depth estimation model.")
            min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.1, minimum=0.0, maximum=20, step=0.01)
        with gradio.Row():
            as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
            mask_sky = gradio.Checkbox(value=False, label="Mask sky")
            clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
            transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
            # not to show camera
            show_cam = gradio.Checkbox(value=True, label="Show Camera")
        run_btn = gradio.Button("Run")
        outmodel = gradio.Model3D()
        # examples = gradio.Examples(
        #     examples=[
        #         ['./example/yellowman/frame_0003.png',
        #          0.0, 1.5, 0.2, True, False,
        #          ]
        #     ],
        #     inputs=[snapshot, matching_conf_thr, min_conf_thr, cam_size, as_pointcloud, shared_intrinsics, inputfiles],
        #     outputs=[filestate, outmodel],
        #     fn=run_example,
        #     cache_examples="lazy",
        # )
        # events
        run_btn.click(fn=_reconstruct_with_selected_prior,
                      inputs=[inputfiles, min_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams,
                              cam_size, depth_prior_name, show_cam],
                      outputs=[outmodel])

try:
    demo.launch(show_error=True, share=None, server_name=None, server_port=None)
finally:
    # Remove the scratch directory even when the server stops with an exception.
    shutil.rmtree(tmpdirname, ignore_errors=True)