import pycolmap
import gradio
import os
import numpy as np
import functools
import trimesh
import copy
from scipy.spatial.transform import Rotation
import tempfile
import shutil
import PIL.Image
import torch

from kapture.converter.colmap.database_extra import kapture_to_colmap
from kapture.converter.colmap.database import COLMAPDatabase

from mast3r.colmap.mapping import kapture_import_image_folder_or_list, run_mast3r_matching, glomap_run_mapper
from mast3r.demo import set_scenegraph_options
from mast3r.retrieval.processor import Retriever
from mast3r.image_pairs import make_pairs

import mast3r.utils.path_to_dust3r
from dust3r.utils.image import load_images
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL
from dust3r.demo import get_args_parser as dust3r_get_args_parser

import matplotlib.pyplot as pl


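# Simple containers: GlomapRecon holds the per-image poses, intrinsics and pixels plus the sparse
# 3D points of a GLOMAP reconstruction; GlomapReconState additionally tracks the cache directory
# and the output GLB file so they can be cleaned up when the state object is garbage-collected.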
class GlomapRecon:
    def __init__(self, world_to_cam, intrinsics, points3d, imgs):
        self.world_to_cam = world_to_cam
        self.intrinsics = intrinsics
        self.points3d = points3d
        self.imgs = imgs


class GlomapReconState:
    def __init__(self, glomap_recon, should_delete=False, cache_dir=None, outfile_name=None):
        self.glomap_recon = glomap_recon
        self.cache_dir = cache_dir
        self.outfile_name = outfile_name
        self.should_delete = should_delete

    def __del__(self):
        if not self.should_delete:
            return
        if self.cache_dir is not None and os.path.isdir(self.cache_dir):
            shutil.rmtree(self.cache_dir)
        self.cache_dir = None
        if self.outfile_name is not None and os.path.isfile(self.outfile_name):
            os.remove(self.outfile_name)
        self.outfile_name = None


def get_args_parser():
    parser = dust3r_get_args_parser()
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--gradio_delete_cache', default=None, type=int,
                        help='age/frequency at which gradio removes its cached files. If >0, the matching cache is purged as well')
    parser.add_argument('--glomap_bin', default='glomap', type=str, help='path to the glomap binary')
    parser.add_argument('--retrieval_model', default=None, type=str, help="retrieval model checkpoint to load")

    actions = parser._actions
    for action in actions:
        if action.dest == 'model_name':
            action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]

    parser.prog = 'mast3r demo'
    return parser


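# Example invocation (illustrative only; the script name and most options, e.g. --device or
# --image_size, come from the dust3r argument parser extended above):
#   python demo_glomap.py --model_name MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric \
#       --glomap_bin glomap --retrieval_model <retrieval checkpoint> --share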
def get_reconstructed_scene(glomap_bin, outdir, gradio_delete_cache, model, retrieval_model, device, silent, image_size,
                            current_scene_state, filelist, transparent_cams, cam_size, scenegraph_type, winsize,
                            win_cyclic, refid, shared_intrinsics, **kw):
    """
    From a list of images: run MASt3R pairwise matching, export the matches to a COLMAP database,
    run the GLOMAP mapper, then build a GLB model with get_3D_model_from_scene.
    """
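    # load and resize the input images; if a single image is given, duplicate it so that
    # pairwise matching still gets a pair to work with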
    imgs = load_images(filelist, size=image_size, verbose=not silent)
    if len(imgs) == 1:
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
        filelist = [filelist[0], filelist[0]]

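    # encode the pairing strategy as a scene-graph string, e.g. "complete", "swin-3-noncyclic"
    # or "retrieval-<winsize>-<refid>"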
    scene_graph_params = [scenegraph_type]
    if scenegraph_type in ["swin", "logwin"]:
        scene_graph_params.append(str(winsize))
    elif scenegraph_type == "oneref":
        scene_graph_params.append(str(refid))
    elif scenegraph_type == "retrieval":
        scene_graph_params.append(str(winsize))
        scene_graph_params.append(str(refid))

    if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
        scene_graph_params.append('noncyclic')
    scene_graph = '-'.join(scene_graph_params)

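    # with retrieval-based pairing, compute an image similarity matrix from the retrieval model,
    # then release the retriever to free GPU memory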
    sim_matrix = None
    if 'retrieval' in scenegraph_type:
        assert retrieval_model is not None
        retriever = Retriever(retrieval_model, backbone=model, device=device)
        with torch.no_grad():
            sim_matrix = retriever(filelist)

        del retriever
        torch.cuda.empty_cache()

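    # expand the scene graph (and optional similarity matrix) into an explicit list of image pairs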
    pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True, sim_mat=sim_matrix)

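    # pick a cache directory: reuse the one from the previous run if it is still valid, otherwise
    # create a gradio-managed temporary directory or a persistent 'cache' folder under outdir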
    if current_scene_state is not None and \
            not current_scene_state.should_delete and \
            current_scene_state.cache_dir is not None:
        cache_dir = current_scene_state.cache_dir
    elif gradio_delete_cache:
        cache_dir = tempfile.mkdtemp(suffix='_cache', dir=outdir)
    else:
        cache_dir = os.path.join(outdir, 'cache')

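    # import the images into a kapture structure, with all paths expressed relative to their common root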
    root_path = os.path.commonpath(filelist)
    filelist_relpath = [
        os.path.relpath(filename, root_path).replace('\\', '/')
        for filename in filelist
    ]
    kdata = kapture_import_image_folder_or_list((root_path, filelist_relpath), shared_intrinsics)
    image_pairs = [
        (filelist_relpath[img1['idx']], filelist_relpath[img2['idx']])
        for img1, img2 in pairs
    ]

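    # start from a fresh COLMAP database and fill it with the kapture cameras/images and the MASt3R matches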
    colmap_db_path = os.path.join(cache_dir, 'colmap.db')
    if os.path.isfile(colmap_db_path):
        os.remove(colmap_db_path)

    os.makedirs(os.path.dirname(colmap_db_path), exist_ok=True)
    colmap_db = COLMAPDatabase.connect(colmap_db_path)
    try:
        kapture_to_colmap(kdata, root_path, tar_handler=None, database=colmap_db,
                          keypoints_type=None, descriptors_type=None, export_two_view_geometry=False)
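        # run MASt3R matching on every pair and write the keypoints/matches into the COLMAP database;
        # the positional literals below (16, False, 5, 1.001, False, 3) are kept from the original call
        # and likely correspond to batch size, dense matching, pixel tolerance, confidence threshold,
        # geometric-verification skipping and minimum track length -- check run_mast3r_matching's
        # signature before relying on this reading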
        colmap_image_pairs = run_mast3r_matching(model, image_size, 16, device,
                                                 kdata, root_path, image_pairs, colmap_db,
                                                 False, 5, 1.001,
                                                 False, 3)
        colmap_db.close()
    except Exception as e:
        print(f'Error {e}')
        colmap_db.close()
        exit(1)

    if len(colmap_image_pairs) == 0:
        raise Exception("no matches were kept")

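    # write the list of matched pairs to disk and let COLMAP run geometric verification on them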
    colmap_world_to_cam = {}
    print("verify_matches")
    pairs_path = os.path.join(cache_dir, 'pairs.txt')
    with open(pairs_path, "w") as f:
        for image_path1, image_path2 in colmap_image_pairs:
            f.write("{} {}\n".format(image_path1, image_path2))
    pycolmap.verify_matches(colmap_db_path, pairs_path)

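    # run the GLOMAP global mapper on the verified database to recover camera poses and sparse points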
    reconstruction_path = os.path.join(cache_dir, "reconstruction")
    if os.path.isdir(reconstruction_path):
        shutil.rmtree(reconstruction_path)
    os.makedirs(reconstruction_path, exist_ok=True)
    glomap_run_mapper(glomap_bin, colmap_db_path, reconstruction_path, root_path)

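    # reuse the previous output GLB path when possible, otherwise create a new temporary file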
    if current_scene_state is not None and \
            not current_scene_state.should_delete and \
            current_scene_state.outfile_name is not None:
        outfile_name = current_scene_state.outfile_name
    else:
        outfile_name = tempfile.mktemp(suffix='_scene.glb', dir=outdir)

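    # load the first GLOMAP model ('0') and extract, for every registered image, its world-to-camera
    # pose, pinhole intrinsics and pixels (the callable() check copes with pycolmap versions where
    # cam_from_world.matrix is a method rather than a plain attribute)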
    output_recon = pycolmap.Reconstruction(os.path.join(reconstruction_path, '0'))
    print(output_recon.summary())

    colmap_world_to_cam = {}
    colmap_intrinsics = {}
    colmap_image_id_to_name = {}
    images = {}
    num_reg_images = output_recon.num_reg_images()
    for idx, (colmap_imgid, colmap_image) in enumerate(output_recon.images.items()):
        colmap_image_id_to_name[colmap_imgid] = colmap_image.name
        if callable(colmap_image.cam_from_world.matrix):
            colmap_world_to_cam[colmap_imgid] = colmap_image.cam_from_world.matrix()
        else:
            colmap_world_to_cam[colmap_imgid] = colmap_image.cam_from_world.matrix
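        # fetch the corresponding camera and assemble a 3x3 pinhole intrinsics matrix K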
        camera = output_recon.cameras[colmap_image.camera_id]
        K = np.eye(3)
        K[0, 0] = camera.focal_length_x
        K[1, 1] = camera.focal_length_y
        K[0, 2] = camera.principal_point_x
        K[1, 2] = camera.principal_point_y
        colmap_intrinsics[colmap_imgid] = K

        with PIL.Image.open(os.path.join(root_path, colmap_image.name)) as im:
            images[colmap_imgid] = np.asarray(im)

        if idx + 1 == num_reg_images:
            break
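    # collect the sparse 3D points as (xyz, rgb) tuples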
    points3D = []
    num_points3D = output_recon.num_points3D()
    for idx, (pt3d_id, pts3d) in enumerate(output_recon.points3D.items()):
        points3D.append((pts3d.xyz, pts3d.color))
        if idx + 1 == num_points3D:
            break
    scene = GlomapRecon(colmap_world_to_cam, colmap_intrinsics, points3D, images)
    scene_state = GlomapReconState(scene, gradio_delete_cache, cache_dir, outfile_name)
    outfile = get_3D_model_from_scene(silent, scene_state, transparent_cams, cam_size)
    return scene_state, outfile


def get_3D_model_from_scene(silent, scene_state, transparent_cams=False, cam_size=0.05):
    """
    Extract a 3D model (GLB file) from a reconstructed scene.
    """
    if scene_state is None:
        return None
    outfile = scene_state.outfile_name
    if outfile is None:
        return None

    recon = scene_state.glomap_recon

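    # dump all reconstructed 3D points into a single colored point cloud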
    scene = trimesh.Scene()
    pts = np.stack([p[0] for p in recon.points3d], axis=0)
    col = np.stack([p[1] for p in recon.points3d], axis=0)
    pct = trimesh.PointCloud(pts, colors=col)
    scene.add_geometry(pct)

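    # for each image, invert the 3x4 world-to-camera pose into a camera-to-world matrix and
    # add a camera frustum (optionally textured with the image) to the scene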
    cams2world = []
    for i, (id, pose_w2c_3x4) in enumerate(recon.world_to_cam.items()):
        intrinsics = recon.intrinsics[id]
        focal = (intrinsics[0, 0] + intrinsics[1, 1]) / 2.0
        camera_edge_color = CAM_COLORS[i % len(CAM_COLORS)]
        pose_w2c = np.eye(4)
        pose_w2c[:3, :] = pose_w2c_3x4
        pose_c2w = np.linalg.inv(pose_w2c)
        cams2world.append(pose_c2w)
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else recon.imgs[id], focal,
                      imsize=recon.imgs[id].shape[1::-1], screen_width=cam_size)

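    # re-orient the exported scene relative to the first camera (OpenGL convention, rotated 180 deg around y)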
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
    if not silent:
        print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)

    return outfile


def main_demo(glomap_bin, tmpdirname, model, retrieval_model, device, image_size, server_name, server_port,
              silent=False, share=False, gradio_delete_cache=False):
    if not silent:
        print('Outputting stuff in', tmpdirname)

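    # pre-bind the server-side arguments; the remaining arguments of both callbacks are provided
    # by the gradio widgets at run time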
    recon_fun = functools.partial(get_reconstructed_scene, glomap_bin, tmpdirname, gradio_delete_cache, model,
                                  retrieval_model, device, silent, image_size)
    model_from_scene_fun = functools.partial(get_3D_model_from_scene, silent)

    available_scenegraph_type = [("complete: all possible image pairs", "complete"),
                                 ("swin: sliding window", "swin"),
                                 ("logwin: sliding window with long range", "logwin"),
                                 ("oneref: match one image with all", "oneref")]
    if retrieval_model is not None:
        available_scenegraph_type.insert(1, ("retrieval: connect views based on similarity", "retrieval"))

    def get_context(delete_cache):
        css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
        title = "MASt3R Demo"
        if delete_cache:
            return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache))
        else:
            return gradio.Blocks(css=css, title="MASt3R Demo")

    with get_context(gradio_delete_cache) as demo:

        scene = gradio.State(None)
        gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
        with gradio.Column():
            inputfiles = gradio.File(file_count="multiple")
            with gradio.Row():
                shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
                                                    info="Only optimize one set of intrinsics for all views")
                scenegraph_type = gradio.Dropdown(available_scenegraph_type,
                                                  value='complete', label="Scenegraph",
                                                  info="Define how to make pairs",
                                                  interactive=True)
                with gradio.Column(visible=False) as win_col:
                    winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
                                            minimum=1, maximum=1, step=1)
                    win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
                refid = gradio.Slider(label="Scene Graph: Id", value=0,
                                      minimum=0, maximum=0, step=1, visible=False)
            run_btn = gradio.Button("Run")

            with gradio.Row():
                cam_size = gradio.Slider(label="cam_size", value=0.01, minimum=0.001, maximum=1.0, step=0.001)
            with gradio.Row():
                transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")

            outmodel = gradio.Model3D()

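        # wire the UI: changes to the inputs or scene-graph controls update the scene-graph widgets,
        # the Run button triggers the full reconstruction, and changing display options only re-exports the GLB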
        scenegraph_type.change(set_scenegraph_options,
                               inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                               outputs=[win_col, winsize, win_cyclic, refid])
        inputfiles.change(set_scenegraph_options,
                          inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                          outputs=[win_col, winsize, win_cyclic, refid])
        win_cyclic.change(set_scenegraph_options,
                          inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
                          outputs=[win_col, winsize, win_cyclic, refid])
        run_btn.click(fn=recon_fun,
                      inputs=[scene, inputfiles, transparent_cams, cam_size,
                              scenegraph_type, winsize, win_cyclic, refid, shared_intrinsics],
                      outputs=[scene, outmodel])
        cam_size.change(fn=model_from_scene_fun,
                        inputs=[scene, transparent_cams, cam_size],
                        outputs=outmodel)
        transparent_cams.change(model_from_scene_fun,
                                inputs=[scene, transparent_cams, cam_size],
                                outputs=outmodel)
    demo.launch(share=share, server_name=server_name, server_port=server_port)