diff --git a/dust3r/__init__.py b/dust3r/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e
--- /dev/null
+++ b/dust3r/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
diff --git a/dust3r/__pycache__/__init__.cpython-310.pyc b/dust3r/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2554667f2a3dea851493e035792ac143bd10cea4
Binary files /dev/null and b/dust3r/__pycache__/__init__.cpython-310.pyc differ
diff --git a/dust3r/__pycache__/__init__.cpython-38.pyc b/dust3r/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb5d6c4f66d26ea8bee390058b16d00ba0a178b7
Binary files /dev/null and b/dust3r/__pycache__/__init__.cpython-38.pyc differ
diff --git a/dust3r/__pycache__/__init__.cpython-39.pyc b/dust3r/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d436cad8c59aee2b22ff5e73c7c77f8ac5eb826c
Binary files /dev/null and b/dust3r/__pycache__/__init__.cpython-39.pyc differ
diff --git a/dust3r/__pycache__/image_pairs.cpython-310.pyc b/dust3r/__pycache__/image_pairs.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..197341b3fd8a6d4675a7446bcdf9520c82dac153
Binary files /dev/null and b/dust3r/__pycache__/image_pairs.cpython-310.pyc differ
diff --git a/dust3r/__pycache__/image_pairs.cpython-38.pyc b/dust3r/__pycache__/image_pairs.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e9b5a6514e633f89cbbf4fab38e705a6d2714dc0
Binary files /dev/null and b/dust3r/__pycache__/image_pairs.cpython-38.pyc differ
diff --git a/dust3r/__pycache__/inference.cpython-310.pyc b/dust3r/__pycache__/inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a53990954586fcc38cbb84d0f03b5c54136cfa74
Binary files /dev/null and b/dust3r/__pycache__/inference.cpython-310.pyc differ
diff --git a/dust3r/__pycache__/inference.cpython-38.pyc b/dust3r/__pycache__/inference.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f566e7340dcfb332c4c7e388cb7bb3157b289528
Binary files /dev/null and b/dust3r/__pycache__/inference.cpython-38.pyc differ
diff --git a/dust3r/__pycache__/inference.cpython-39.pyc b/dust3r/__pycache__/inference.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bce1d02d8e940f348f3373351f2264ff1b634fd
Binary files /dev/null and b/dust3r/__pycache__/inference.cpython-39.pyc differ
diff --git a/dust3r/__pycache__/model.cpython-310.pyc b/dust3r/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bba7efd5fa18e8843337b8c4d2e3a0f568eb1710
Binary files /dev/null and b/dust3r/__pycache__/model.cpython-310.pyc differ
diff --git a/dust3r/__pycache__/model.cpython-38.pyc b/dust3r/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..421285c1abb1bd2c92255c70131c55b6eb962798
Binary files /dev/null and b/dust3r/__pycache__/model.cpython-38.pyc differ
diff --git a/dust3r/__pycache__/model.cpython-39.pyc b/dust3r/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b3fc91a79d56428b58d8a9aec928db114981371
Binary files /dev/null and b/dust3r/__pycache__/model.cpython-39.pyc differ
diff --git a/dust3r/__pycache__/optim_factory.cpython-310.pyc b/dust3r/__pycache__/optim_factory.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c4675b7111b6496670413ab3d7ac69e05e62817
Binary files /dev/null and b/dust3r/__pycache__/optim_factory.cpython-310.pyc differ
diff --git a/dust3r/__pycache__/optim_factory.cpython-38.pyc b/dust3r/__pycache__/optim_factory.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b86abf43210c3a958beee358b83bc4f25300f57
Binary files /dev/null and b/dust3r/__pycache__/optim_factory.cpython-38.pyc differ
diff --git a/dust3r/__pycache__/patch_embed.cpython-310.pyc b/dust3r/__pycache__/patch_embed.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d62cef41468bc157c09203e594f036c419ecc7f
Binary files /dev/null and b/dust3r/__pycache__/patch_embed.cpython-310.pyc differ
diff --git a/dust3r/__pycache__/patch_embed.cpython-38.pyc b/dust3r/__pycache__/patch_embed.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b296aa18c05ec232f66356b618da42d13f88e81
Binary files /dev/null and b/dust3r/__pycache__/patch_embed.cpython-38.pyc differ
diff --git a/dust3r/__pycache__/post_process.cpython-310.pyc b/dust3r/__pycache__/post_process.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7e639f8b862956da232cab176c2f1c34be121a97
Binary files /dev/null and b/dust3r/__pycache__/post_process.cpython-310.pyc differ
diff --git a/dust3r/__pycache__/render_to_3d.cpython-310.pyc b/dust3r/__pycache__/render_to_3d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..596ac63de7b993b805ba9f48263907747817e56d
Binary files /dev/null and b/dust3r/__pycache__/render_to_3d.cpython-310.pyc differ
diff --git a/dust3r/__pycache__/viz.cpython-310.pyc b/dust3r/__pycache__/viz.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..31c936fe6b04a03e774e7d3096862034b882e61f
Binary files /dev/null and b/dust3r/__pycache__/viz.cpython-310.pyc differ
diff --git a/dust3r/__pycache__/viz.cpython-38.pyc b/dust3r/__pycache__/viz.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..348ee4235b6cc92884dcea9c2c6b586da932ea57
Binary files /dev/null and b/dust3r/__pycache__/viz.cpython-38.pyc differ
diff --git a/dust3r/cloud_opt/__init__.py b/dust3r/cloud_opt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc597c702861154bbe7a08f23b089474e926bb35
--- /dev/null
+++ b/dust3r/cloud_opt/__init__.py
@@ -0,0 +1,29 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# global alignment optimization wrapper function
+# --------------------------------------------------------
+from enum import Enum
+
+from .optimizer import PointCloudOptimizer
+from .pair_viewer import PairViewer
+
+
+class GlobalAlignerMode(Enum):
+    PointCloudOptimizer = "PointCloudOptimizer"
+    PairViewer = "PairViewer"
+
+
+def global_aligner(dust3r_output, device, mode=GlobalAlignerMode.PointCloudOptimizer, **optim_kw):
+    # extract all inputs
+    view1, view2, pred1, pred2 = [dust3r_output[k] for k in 'view1 view2 pred1 pred2'.split()]
+    # build the optimizer
+    if mode == GlobalAlignerMode.PointCloudOptimizer:
+        net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
+    elif mode == GlobalAlignerMode.PairViewer:
+        net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device)
+    else:
+        raise NotImplementedError(f'Unknown mode {mode}')
+
+    return net
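+
+
+# Typical usage (a minimal sketch; it assumes the usual dust3r helpers
+# `load_images`, `make_pairs` and `inference`, and hypothetical image paths):
+#
+#   from dust3r.inference import inference
+#   from dust3r.image_pairs import make_pairs
+#   from dust3r.utils.image import load_images
+#   images = load_images(['img_a.jpg', 'img_b.jpg'], size=512)
+#   pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
+#   output = inference(pairs, model, device, batch_size=1)
+#   scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)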
diff --git a/dust3r/cloud_opt/__pycache__/__init__.cpython-310.pyc b/dust3r/cloud_opt/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e0a2b51c03de3e985bd0aaa8fdcca3df53afc24
Binary files /dev/null and b/dust3r/cloud_opt/__pycache__/__init__.cpython-310.pyc differ
diff --git a/dust3r/cloud_opt/__pycache__/__init__.cpython-38.pyc b/dust3r/cloud_opt/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6f50c36cc3fd01bee45dd96d2b4b74abc5020bb8
Binary files /dev/null and b/dust3r/cloud_opt/__pycache__/__init__.cpython-38.pyc differ
diff --git a/dust3r/cloud_opt/__pycache__/base_opt.cpython-310.pyc b/dust3r/cloud_opt/__pycache__/base_opt.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..966f34000eac726ecbae1176d59f36ecd8d5689d
Binary files /dev/null and b/dust3r/cloud_opt/__pycache__/base_opt.cpython-310.pyc differ
diff --git a/dust3r/cloud_opt/__pycache__/base_opt.cpython-38.pyc b/dust3r/cloud_opt/__pycache__/base_opt.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5c72e158564414b21d654a3bfb09deec9c8062e0
Binary files /dev/null and b/dust3r/cloud_opt/__pycache__/base_opt.cpython-38.pyc differ
diff --git a/dust3r/cloud_opt/__pycache__/commons.cpython-310.pyc b/dust3r/cloud_opt/__pycache__/commons.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80a51ff436858899e0394ec879a6fefea4dca7fd
Binary files /dev/null and b/dust3r/cloud_opt/__pycache__/commons.cpython-310.pyc differ
diff --git a/dust3r/cloud_opt/__pycache__/commons.cpython-38.pyc b/dust3r/cloud_opt/__pycache__/commons.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c309db1757053214244504ee4ef54568ea10ec1d
Binary files /dev/null and b/dust3r/cloud_opt/__pycache__/commons.cpython-38.pyc differ
diff --git a/dust3r/cloud_opt/__pycache__/init_im_poses.cpython-310.pyc b/dust3r/cloud_opt/__pycache__/init_im_poses.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..028d8d55934e604ed7b856eb501b9fd59a55e4ea
Binary files /dev/null and b/dust3r/cloud_opt/__pycache__/init_im_poses.cpython-310.pyc differ
diff --git a/dust3r/cloud_opt/__pycache__/init_im_poses.cpython-38.pyc b/dust3r/cloud_opt/__pycache__/init_im_poses.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3b3d310d3d6522379fb57d29d123ecbb5d6392f5
Binary files /dev/null and b/dust3r/cloud_opt/__pycache__/init_im_poses.cpython-38.pyc differ
diff --git a/dust3r/cloud_opt/__pycache__/optimizer.cpython-310.pyc b/dust3r/cloud_opt/__pycache__/optimizer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71100cc85b822d51caf38707f40d81cd3639d97e
Binary files /dev/null and b/dust3r/cloud_opt/__pycache__/optimizer.cpython-310.pyc differ
diff --git a/dust3r/cloud_opt/__pycache__/optimizer.cpython-38.pyc b/dust3r/cloud_opt/__pycache__/optimizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79fa1b943f67dddbd8c4e0beafb2a19a365c5825
Binary files /dev/null and b/dust3r/cloud_opt/__pycache__/optimizer.cpython-38.pyc differ
diff --git a/dust3r/cloud_opt/__pycache__/pair_viewer.cpython-310.pyc b/dust3r/cloud_opt/__pycache__/pair_viewer.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78f5a5e6821b1b033f1154061b87d3631dae320c
Binary files /dev/null and b/dust3r/cloud_opt/__pycache__/pair_viewer.cpython-310.pyc differ
diff --git a/dust3r/cloud_opt/base_opt.py b/dust3r/cloud_opt/base_opt.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ff0b04e9091bd1283b6302c4a3f166a17908e2d
--- /dev/null
+++ b/dust3r/cloud_opt/base_opt.py
@@ -0,0 +1,380 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Base class for the global alignment procedure
+# --------------------------------------------------------
+from copy import deepcopy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import roma
+import tqdm
+
+from dust3r.utils.geometry import inv, geotrf
+from dust3r.utils.device import to_numpy
+from dust3r.utils.image import rgb
+from dust3r.viz import SceneViz, segment_sky, auto_cam_size
+from dust3r.optim_factory import adjust_learning_rate_by_lr
+
+from dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
+                                      cosine_schedule, linear_schedule, get_conf_trf)
+import dust3r.cloud_opt.init_im_poses as init_fun
+
+
+class BasePCOptimizer (nn.Module):
+    """ Optimize a global scene, given a list of pairwise observations.
+    Graph node: images
+    Graph edges: observations = (pred1, pred2)
+    """
+
+    def __init__(self, *args, **kwargs):
+        if len(args) == 1 and len(kwargs) == 0:
+            other = deepcopy(args[0])
+            attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes 
+                        min_conf_thr conf_thr conf_i conf_j im_conf
+                        base_scale norm_pw_scale POSE_DIM pw_poses 
+                        pw_adaptors has_im_poses rand_pose imgs'''.split()
+            self.__dict__.update({k: other[k] for k in attrs})
+        else:
+            self._init_from_views(*args, **kwargs)
+
+    def _init_from_views(self, view1, view2, pred1, pred2,
+                         dist='l1',
+                         conf='log',
+                         min_conf_thr=3,
+                         base_scale=0.5,
+                         allow_pw_adaptors=False,
+                         pw_break=20,
+                         rand_pose=torch.randn,
+                         iterationsCount=None,
+                        ):
+        super().__init__()
+        if not isinstance(view1['idx'], list):
+            view1['idx'] = view1['idx'].tolist()
+        if not isinstance(view2['idx'], list):
+            view2['idx'] = view2['idx'].tolist()
+        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
+        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
+        self.dist = ALL_DISTS[dist]
+
+        self.n_imgs = self._check_edges()
+
+        # input data
+        pred1_pts = pred1['pts3d']
+        pred2_pts = pred2['pts3d_in_other_view']
+        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
+        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
+        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)
+
+        # work in log-scale with conf
+        pred1_conf = pred1['conf']
+        pred2_conf = pred2['conf']
+        self.min_conf_thr = min_conf_thr
+        self.conf_trf = get_conf_trf(conf)
+
+        self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
+        self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
+        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
+
+        # pairwise pose parameters
+        self.base_scale = base_scale
+        self.norm_pw_scale = True
+        self.pw_break = pw_break
+        self.POSE_DIM = 7
+        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))  # pairwise poses
+        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # slight xy/z adaptation
+        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
+        self.has_im_poses = False
+        self.rand_pose = rand_pose
+
+        # possibly store images for show_pointcloud
+        self.imgs = None
+        if 'img' in view1 and 'img' in view2:
+            imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
+            for v in range(len(self.edges)):
+                idx = view1['idx'][v]
+                imgs[idx] = view1['img'][v]
+                idx = view2['idx'][v]
+                imgs[idx] = view2['img'][v]
+            self.imgs = rgb(imgs)
+
+    @property
+    def n_edges(self):
+        return len(self.edges)
+
+    @property
+    def str_edges(self):
+        return [edge_str(i, j) for i, j in self.edges]
+
+    @property
+    def imsizes(self):
+        return [(w, h) for h, w in self.imshapes]
+
+    @property
+    def device(self):
+        return next(iter(self.parameters())).device
+
+    def state_dict(self, trainable=True):
+        all_params = super().state_dict()
+        return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}
+
+    def load_state_dict(self, data):
+        return super().load_state_dict(self.state_dict(trainable=False) | data)
+
+    def _check_edges(self):
+        indices = sorted({i for edge in self.edges for i in edge})
+        assert indices == list(range(len(indices))), 'bad pair indices: missing values '
+        return len(indices)
+
+    @torch.no_grad()
+    def _compute_img_conf(self, pred1_conf, pred2_conf):
+        im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
+        for e, (i, j) in enumerate(self.edges):
+            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
+            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
+        return im_conf
+
+    def get_adaptors(self):  # the per-edge scale factor σ_e of Eq. (5)
+        adapt = self.pw_adaptors
+        adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1)  # (scale_xy, scale_xy, scale_z)
+        if self.norm_pw_scale:  # normalize so that the product == 1
+            adapt = adapt - adapt.mean(dim=1, keepdim=True)  # zero-mean in log-space
+        return (adapt / self.pw_break).exp()  # TODO gys: what exactly is σ_e in Eq. (5)?
+
+    def _get_poses(self, poses):  # poses is either self.im_poses or self.pw_poses
+        # normalize rotation
+        Q = poses[:, :4]
+        T = signed_expm1(poses[:, 4:7])
+        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
+        return RT
+
+    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
+        # all poses == cam-to-world
+        pose = poses[idx]
+        if not (pose.requires_grad or force):
+            return pose
+
+        if R.shape == (4, 4):
+            assert T is None
+            T = R[:3, 3]
+            R = R[:3, :3]
+
+        if R is not None:
+            pose.data[0:4] = roma.rotmat_to_unitquat(R)
+        if T is not None:
+            pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is function of scale
+
+        if scale is not None:
+            assert poses.shape[-1] in (8, 13)
+            pose.data[-1] = np.log(float(scale))
+        return pose
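+
+    # Pose encoding used by _set_pose / _get_poses above: each pose row stores a unit
+    # quaternion (4 values) followed by a translation compressed with signed_log1p
+    # (3 values), i.e. POSE_DIM = 7; pairwise poses carry one extra trailing log-scale
+    # (rows of length 1 + POSE_DIM = 8). _get_poses() undoes the encoding with
+    # signed_expm1 and roma.RigidUnitQuat.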
+
+    def get_pw_norm_scale_factor(self):
+        if self.norm_pw_scale:
+            # normalize scales so that things cannot go south
+            # we want that exp(scale) ~= self.base_scale
+            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
+        else:
+            return 1  # don't norm scale for known poses
+
+    def get_pw_scale(self):
+        scale = self.pw_poses[:, -1].exp()  # (n_edges,)
+        scale = scale * self.get_pw_norm_scale_factor()
+        return scale
+
+    def get_pw_poses(self):  # cam to world
+        RT = self._get_poses(self.pw_poses)
+        scaled_RT = RT.clone()
+        scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)  # scale the rotation AND translation
+        return scaled_RT
+
+    def get_masks(self):
+        return [(conf > self.min_conf_thr) for conf in self.im_conf]
+
+    def depth_to_pts3d(self):
+        raise NotImplementedError()
+
+    def get_pts3d(self, raw=False):
+        res = self.depth_to_pts3d()
+        if not raw:
+            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
+        return res
+
+    def _set_focal(self, idx, focal, force=False):
+        raise NotImplementedError()
+
+    def get_focals(self):
+        raise NotImplementedError()
+
+    def get_known_focal_mask(self):
+        raise NotImplementedError()
+
+    def get_principal_points(self):
+        raise NotImplementedError()
+
+    def get_conf(self, mode=None):
+        trf = self.conf_trf if mode is None else get_conf_trf(mode)
+        return [trf(c) for c in self.im_conf]
+
+    def get_im_poses(self):
+        raise NotImplementedError()
+
+    def _set_depthmap(self, idx, depth, force=False):
+        raise NotImplementedError()
+
+    def get_depthmaps(self, raw=False):
+        raise NotImplementedError()
+
+    @torch.no_grad()
+    def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
+        """ Method: 
+        1) express all 3d points in each camera coordinate frame
+        2) if they're in front of a depthmap --> then lower their confidence
+        """
+        assert 0 <= tol < 1
+        cams = inv(self.get_im_poses())
+        K = self.get_intrinsics()
+        depthmaps = self.get_depthmaps()
+        res = deepcopy(self)
+
+        for i, pts3d in enumerate(self.depth_to_pts3d()):
+            for j in range(self.n_imgs):
+                if i == j:
+                    continue
+
+                # project 3dpts in other view
+                Hi, Wi = self.imshapes[i]
+                Hj, Wj = self.imshapes[j]
+                proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
+                proj_depth = proj[:, :, 2]
+                u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)
+
+                # check which points are actually in the visible cone
+                msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
+                msk_j = v[msk_i], u[msk_i]
+
+                # find bad points = those in front but less confident
+                bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]
+                              ) & (res.im_conf[i][msk_i] < res.im_conf[j][msk_j])
+
+                bad_msk_i = msk_i.clone()
+                bad_msk_i[msk_i] = bad_points
+                res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_(max=max_bad_conf)
+
+        return res
+
+    def forward(self, ret_details=False):
+        pw_poses = self.get_pw_poses()  # cam-to-world
+        pw_adapt = self.get_adaptors()
+        proj_pts3d = self.get_pts3d()
+        # pre-compute pixel weights
+        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
+        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}
+
+        loss = 0
+        if ret_details:
+            details = -torch.ones((self.n_imgs, self.n_imgs))
+
+        for e, (i, j) in enumerate(self.edges):
+            i_j = edge_str(i, j)
+            # distance in image i and j
+            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
+            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
+            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
+            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
+            loss = loss + li + lj
+
+            if ret_details:
+                details[i, j] = li + lj
+        loss /= self.n_edges  # average over all pairs
+
+        if ret_details:
+            return loss, details
+        return loss
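+
+    # In formula form (a paraphrase of the loop above, using the notation of Eq. (5)):
+    #   loss = (1/E) * sum_e [ mean( w_i^e * || chi_i - sigma_e P_e X_i^e || )
+    #                        + mean( w_j^e * || chi_j - sigma_e P_e X_j^e || ) ]
+    # where chi are the optimized world-frame pointmaps, P_e / sigma_e the pairwise pose and
+    # scale adaptors, X^e the dust3r predictions and w^e their conf_trf-transformed confidences.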
+
+    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
+        if init is None:
+            pass
+        elif init == 'msp' or init == 'mst':
+            # Sec. 3.3 (Downstream Applications): initialize the intrinsics, extrinsics and the world-frame points that Eq. (5) of Sec. 3.4 (Global Alignment) will optimize
+            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
+        elif init == 'known_poses':
+            init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr, niter_PnP=niter_PnP)
+        else:
+            raise ValueError(f'bad value for {init=}')
+
+        global_alignment_loop(self, **kw)  # Sec. 3.4 Global Alignment: gradient descent on Eq. (5)
+
+    @torch.no_grad()
+    def mask_sky(self):
+        res = deepcopy(self)
+        for i in range(self.n_imgs):
+            sky = segment_sky(self.imgs[i])
+            res.im_conf[i][sky] = 0
+        return res
+
+    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
+        viz = SceneViz()
+        if self.imgs is None:
+            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
+            colors = list(map(tuple, colors.tolist()))
+            for n in range(self.n_imgs):
+                viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
+        else:
+            viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
+            colors = np.random.randint(256, size=(self.n_imgs, 3))
+
+        # camera poses
+        im_poses = to_numpy(self.get_im_poses())
+        if cam_size is None:
+            cam_size = auto_cam_size(im_poses)
+        viz.add_cameras(im_poses, self.get_focals(), colors=colors,
+                        images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
+        if show_pw_cams:
+            pw_poses = self.get_pw_poses()
+            viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)
+
+            if show_pw_pts3d:
+                pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
+                viz.add_pointcloud(pts, (128, 0, 128))
+
+        viz.show(**kw)
+        return viz
+
+
+def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6, verbose=False):
+    params = [p for p in net.parameters() if p.requires_grad]
+    if not params:
+        return net
+
+    if verbose:
+        print([name for name, value in net.named_parameters() if value.requires_grad])
+
+    lr_base = lr
+    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))
+
+    with tqdm.tqdm(total=niter) as bar:
+        while bar.n < bar.total:
+            t = bar.n / bar.total
+
+            if schedule == 'cosine':
+                lr = cosine_schedule(t, lr_base, lr_min)
+            elif schedule == 'linear':
+                lr = linear_schedule(t, lr_base, lr_min)
+            else:
+                raise ValueError(f'bad lr {schedule=}')
+            adjust_learning_rate_by_lr(optimizer, lr)
+
+            optimizer.zero_grad()
+            loss = net()  # the "Global optimization" step of the paper
+            loss.backward()
+            optimizer.step()
+            loss = float(loss)
+            bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
+            if bar.n % 30 == 0:
+                print(' ')
+            bar.update()
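+
+
+# Typical call sequence (a sketch): once a scene has been built with global_aligner(),
+# the alignment of Eq. (5) is run with, for example,
+#   scene.compute_global_alignment(init='mst', niter=300, schedule='cosine', lr=0.01)
+# after which scene.get_im_poses(), scene.get_focals() and scene.get_pts3d() return the
+# recovered cameras and point clouds.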
diff --git a/dust3r/cloud_opt/commons.py b/dust3r/cloud_opt/commons.py
new file mode 100644
index 0000000000000000000000000000000000000000..052462e766c67282952f6f6e147c4d927e8ce486
--- /dev/null
+++ b/dust3r/cloud_opt/commons.py
@@ -0,0 +1,91 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utility functions for global alignment
+# --------------------------------------------------------
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def edge_str(i, j):
+    return f'{i}_{j}'
+
+
+def i_j_ij(ij):
+    return edge_str(*ij), ij
+
+
+def edge_conf(conf_i, conf_j, edge):
+    return float(conf_i[edge].mean() * conf_j[edge].mean())
+    # dust3r outputs one confidence map per image of an edge; the edge score is the product of the two per-image mean confidences
+
+
+def compute_edge_scores(edges, conf_i, conf_j):  # score every edge by the product of its two mean confidence maps
+    return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges}
+
+
+def NoGradParamDict(x):
+    assert isinstance(x, dict)
+    return nn.ParameterDict(x).requires_grad_(False)
+
+
+def get_imshapes(edges, pred_i, pred_j):
+    n_imgs = max(max(e) for e in edges) + 1
+    imshapes = [None] * n_imgs
+    for e, (i, j) in enumerate(edges):
+        shape_i = tuple(pred_i[e].shape[0:2])
+        shape_j = tuple(pred_j[e].shape[0:2])
+        if imshapes[i]:
+            assert imshapes[i] == shape_i, f'incorrect shape for image {i}'
+        if imshapes[j]:
+            assert imshapes[j] == shape_j, f'incorrect shape for image {j}'
+        imshapes[i] = shape_i
+        imshapes[j] = shape_j
+    return imshapes
+
+
+def get_conf_trf(mode):
+    if mode == 'log':
+        def conf_trf(x): return x.log()
+    elif mode == 'sqrt':
+        def conf_trf(x): return x.sqrt()
+    elif mode == 'm1':
+        def conf_trf(x): return x-1
+    elif mode in ('id', 'none'):
+        def conf_trf(x): return x
+    else:
+        raise ValueError(f'bad mode for {mode=}')
+    return conf_trf
+
+
+def l2_dist(a, b, weight):
+    return ((a - b).square().sum(dim=-1) * weight)
+
+
+def l1_dist(a, b, weight):
+    return ((a - b).norm(dim=-1) * weight)  # Tensor.norm() defaults to the Euclidean (L2) norm over the last dim
+
+
+ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
+
+
+def signed_log1p(x):
+    sign = torch.sign(x)
+    return sign * torch.log1p(torch.abs(x))
+
+
+def signed_expm1(x):
+    sign = torch.sign(x)
+    return sign * torch.expm1(torch.abs(x))
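+# Note: signed_expm1 is the exact inverse of signed_log1p (both are odd functions), so
+# translations stored with signed_log1p are recovered exactly, up to floating-point error.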
+
+
+def cosine_schedule(t, lr_start, lr_end):
+    assert 0 <= t <= 1
+    return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2
+
+
+def linear_schedule(t, lr_start, lr_end):
+    assert 0 <= t <= 1
+    return lr_start + (lr_end - lr_start) * t
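+
+
+# Illustrative values for the schedules above: both interpolate from lr_start at t=0
+# down to lr_end at t=1, e.g.
+#   cosine_schedule(0.0, 0.01, 1e-6) == 0.01
+#   cosine_schedule(1.0, 0.01, 1e-6) == 1e-6   (up to rounding)
+#   linear_schedule(0.5, 0.01, 1e-6) ~= 0.005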
diff --git a/dust3r/cloud_opt/init_im_poses.py b/dust3r/cloud_opt/init_im_poses.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a4d265c9657d2fdcf3b7bde0413157332523608
--- /dev/null
+++ b/dust3r/cloud_opt/init_im_poses.py
@@ -0,0 +1,316 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Initialization functions for global alignment
+# --------------------------------------------------------
+from functools import cache
+
+import numpy as np
+import scipy.sparse as sp
+import torch
+import cv2
+import roma
+from tqdm import tqdm
+
+from dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
+from dust3r.post_process import estimate_focal_knowing_depth
+from dust3r.viz import to_numpy
+
+from dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores
+
+
+@torch.no_grad()
+def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
+    device = self.device
+
+    # indices of known poses
+    nkp, known_poses_msk, known_poses = get_known_poses(self)
+    assert nkp == self.n_imgs, 'not all poses are known'
+
+    # get all focals
+    nkf, _, im_focals = get_known_focals(self)
+    assert nkf == self.n_imgs
+    im_pp = self.get_principal_points()
+
+    best_depthmaps = {}
+    # init all pairwise poses
+    for e, (i, j) in enumerate(tqdm(self.edges)):
+        i_j = edge_str(i, j)
+
+        # find relative pose for this pair
+        P1 = torch.eye(4, device=device)
+        msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
+        _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),
+                         pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)
+
+        # align the two predicted camera with the two gt cameras
+        s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
+        # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
+        # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
+        self._set_pose(self.pw_poses, e, R, T, scale=s)
+
+        # remember if this is a good depthmap
+        score = float(self.conf_i[i_j].mean())
+        if score > best_depthmaps.get(i, (0,))[0]:
+            best_depthmaps[i] = score, i_j, s
+
+    # init all image poses
+    for n in range(self.n_imgs):
+        assert known_poses_msk[n]
+        _, i_j, scale = best_depthmaps[n]
+        depth = self.pred_i[i_j][:, :, 2]
+        self._set_depthmap(n, depth * scale)
+
+
+@torch.no_grad()
+def init_minimum_spanning_tree(self, **kw):
+    """ Init all camera poses (image-wise and pairwise poses) given
+        an initial set of pairwise estimations.
+    """
+    device = self.device
+    pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges,
+                                                          self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr,
+                                                          device, has_im_poses=self.has_im_poses, **kw)
+
+    return init_from_pts3d(self, pts3d, im_focals, im_poses)  # initialize poses, focals and depth maps
+
+
+def init_from_pts3d(self, pts3d, im_focals, im_poses):
+    # init poses
+    nkp, known_poses_msk, known_poses = get_known_poses(self)
+    if nkp == 1:
+        raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
+    elif nkp > 1:
+        # global rigid SE3 alignment
+        s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
+        trf = sRT_to_4x4(s, R, T, device=known_poses.device)
+
+        # rotate everything
+        im_poses = trf @ im_poses
+        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
+        for img_pts3d in pts3d:
+            img_pts3d[:] = geotrf(trf, img_pts3d)
+
+    # pw_poses: for every edge, estimate P_e, the transform from the edge's camera frame (the frame of the first image fed to dust3r) to the "world" frame
+    for e, (i, j) in enumerate(self.edges):
+        i_j = edge_str(i, j)
+        # compute transform that goes from cam to world
+        # pred_i: the 3D point cloud predicted by dust3r for the first image of the edge
+        s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j])  # extrinsics from this edge's camera frame to the world frame
+        self._set_pose(self.pw_poses, e, R, T, scale=s)  # store into pw_poses
+
+    # take the scale normalization into account (TODO gys: what exactly is s_factor?)
+    s_factor = self.get_pw_norm_scale_factor()
+    im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
+    for img_pts3d in pts3d:
+        img_pts3d *= s_factor
+
+    # init all image poses
+    if self.has_im_poses:
+        for i in range(self.n_imgs):
+            cam2world = im_poses[i]
+            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]  # transform the world-frame points pts3d[i] back into the camera frame
+            self._set_depthmap(i, depth)
+            self._set_pose(self.im_poses, i, cam2world)  # store into im_poses
+            if im_focals[i] is not None:
+                self._set_focal(i, im_focals[i])
+
+    print(' init loss =', float(self()))
+
+
+def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
+                          device, has_im_poses=True, niter_PnP=10):
+    n_imgs = len(imshapes)
+    sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j))  # negated confidence score of every pairwise edge, stored as a sparse matrix
+    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()  # minimum spanning tree of the negated scores, i.e. the maximum-confidence spanning tree
+    # the spanning tree keeps, for every image, the most confident edges among all pairwise edges
+    # temp variable to store 3d points
+    pts3d = [None] * len(imshapes)  # one (initially empty) slot per input image
+
+    todo = sorted(zip(-msp.data, msp.row, msp.col))  # spanning-tree edges sorted by confidence; the n-1 tree edges cover all n input images
+    im_poses = [None] * n_imgs
+    im_focals = [None] * n_imgs
+
+    # init with strongest edge
+    score, i, j = todo.pop()  # score is the edge confidence computed by compute_edge_scores
+    print(f' init edge ({i}*,{j}*) {score=}')
+    i_j = edge_str(i, j)
+    pts3d[i] = pred_i[i_j].clone()  # the two point clouds predicted by dust3r for the most confident edge
+    pts3d[j] = pred_j[i_j].clone()
+    done = {i, j}
+    if has_im_poses:  # the camera frame of the first image of the most confident edge becomes the world frame
+        im_poses[i] = torch.eye(4, device=device)  # identity extrinsics: this camera frame is the world frame
+        im_focals[i] = estimate_focal(pred_i[i_j])  # Sec. 3.3: estimate the intrinsics (focal)
+
+    # set initial pointcloud based on pairwise graph
+    msp_edges = [(i, j)]
+    while todo:
+        # each time, predict the next one
+        score, i, j = todo.pop()  # pop() takes the last, i.e. highest-scoring, remaining edge
+
+        if im_focals[i] is None:  # focal of image i has not been estimated yet
+            im_focals[i] = estimate_focal(pred_i[i_j])
+
+        if i in done:
+            print(f' init edge ({i},{j}*) {score=}')
+            assert j not in done
+            # align pred[i] with pts3d[i], and then set j accordingly
+            i_j = edge_str(i, j)
+            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])  # Sec. 3.3 extrinsics estimation (s is the scale sigma), implemented with the roma package
+            trf = sRT_to_4x4(s, R, T, device)  # pack into a 4x4 homogeneous matrix (last row [0, 0, 0, 1])
+            pts3d[j] = geotrf(trf, pred_j[i_j])  # pred_j[i_j] is dust3r's point cloud of image j expressed in the camera frame of image i
+            done.add(j)
+            msp_edges.append((i, j))
+
+            if has_im_poses and im_poses[i] is None:
+                im_poses[i] = sRT_to_4x4(1, R, T, device)
+
+        elif j in done:
+            print(f' init edge ({i}*,{j}) {score=}')
+            assert i not in done
+            i_j = edge_str(i, j)
+            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])  # transform mapping pred_j[i_j] onto pts3d[j]
+            trf = sRT_to_4x4(s, R, T, device)
+            pts3d[i] = geotrf(trf, pred_i[i_j])  # apply the estimated transform to bring the camera-frame points into the world frame
+            done.add(i)
+            msp_edges.append((i, j))
+
+            if has_im_poses and im_poses[i] is None:
+                im_poses[i] = sRT_to_4x4(1, R, T, device)
+        else:
+            # let's try again later
+            todo.insert(0, (score, i, j))
+
+    if has_im_poses:
+        # complete all missing information
+        pair_scores = list(sparse_graph.values())  # already negative scores: less is best
+        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
+        for i, j in edges_from_best_to_worse.tolist():
+            if im_focals[i] is None:
+                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])
+
+        for i in range(n_imgs):
+            if im_poses[i] is None:
+                msk = im_conf[i] > min_conf_thr  # estimate the remaining extrinsics with RANSAC-PnP
+                res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)
+                if res:
+                    im_focals[i], im_poses[i] = res
+            if im_poses[i] is None:
+                im_poses[i] = torch.eye(4, device=device)
+        im_poses = torch.stack(im_poses)
+    else:
+        im_poses = im_focals = None
+
+    return pts3d, msp_edges, im_focals, im_poses  # pts3d: per-image points already expressed in the world frame (camera-frame points mapped through im_poses)
+
+
+def dict_to_sparse_graph(dic):
+    n_imgs = max(max(e) for e in dic) + 1  # number of images
+    res = sp.dok_array((n_imgs, n_imgs))
+    for edge, value in dic.items():
+        res[edge] = value
+    return res  # edge confidences gathered into an n_imgs x n_imgs sparse matrix
+
+
+def rigid_points_registration(pts1, pts2, conf):
+    R, T, s = roma.rigid_points_registration(  # provided by the roma package
+        pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True)
+    return s, R, T  # return un-scaled (R, T)
+
+
+def sRT_to_4x4(scale, R, T, device):
+    trf = torch.eye(4, device=device)  # identity matrix
+    trf[:3, :3] = R * scale
+    trf[:3, 3] = T.ravel()  # doesn't need scaling
+    return trf  # 4x4 homogeneous extrinsic matrix
+
+
+def estimate_focal(pts3d_i, pp=None):
+    if pp is None:
+        H, W, THREE = pts3d_i.shape
+        assert THREE == 3
+        pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
+    focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(
+        0), focal_mode='weiszfeld', min_focal=0.5, max_focal=3.5).ravel()
+    return float(focal)
+
+
+@cache
+def pixel_grid(H, W):
+    return np.mgrid[:W, :H].T.astype(np.float32)
+
+
+def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
+    # extract camera poses and focals with RANSAC-PnP
+    if msk.sum() < 4:
+        return None  # we need at least 4 points for PnP
+    pts3d, msk = map(to_numpy, (pts3d, msk))
+
+    H, W, THREE = pts3d.shape
+    assert THREE == 3
+    pixels = pixel_grid(H, W)
+
+    if focal is None:
+        S = max(W, H)
+        tentative_focals = np.geomspace(S/2, S*3, 21)
+    else:
+        tentative_focals = [focal]
+
+    if pp is None:
+        pp = (W/2, H/2)
+    else:
+        pp = to_numpy(pp)
+
+    best = 0,
+    for focal in tentative_focals:
+        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
+
+        success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
+                                                    iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
+        if not success:
+            continue
+
+        score = len(inliers)
+        if success and score > best[0]:
+            best = score, R, T, focal
+
+    if not best[0]:
+        return None
+
+    _, R, T, best_focal = best
+    R = cv2.Rodrigues(R)[0]  # world to cam
+    R, T = map(torch.from_numpy, (R, T))
+    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world
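+# Note on fast_pnp above: when no focal is given, it sweeps 21 tentative focals
+# (np.geomspace between S/2 and 3*S, with S = max(W, H)) and keeps the RANSAC-PnP solution
+# with the most inliers; the recovered world-to-cam pose is inverted so that the function
+# returns a cam-to-world matrix, consistent with the rest of the code.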
+
+
+def get_known_poses(self):
+    if self.has_im_poses:
+        known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses])
+        known_poses = self.get_im_poses()
+        return known_poses_msk.sum(), known_poses_msk, known_poses
+    else:
+        return 0, None, None
+
+
+def get_known_focals(self):
+    if self.has_im_poses:
+        known_focal_msk = self.get_known_focal_mask()
+        known_focals = self.get_focals()
+        return known_focal_msk.sum(), known_focal_msk, known_focals
+    else:
+        return 0, None, None
+
+
+def align_multiple_poses(src_poses, target_poses):
+    N = len(src_poses)
+    assert src_poses.shape == target_poses.shape == (N, 4, 4)
+
+    def center_and_z(poses):
+        eps = get_med_dist_between_poses(poses) / 100
+        return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2]))
+    R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True)
+    return s, R, T
diff --git a/dust3r/cloud_opt/optimizer.py b/dust3r/cloud_opt/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f778d45d2f1dd781c415b6a6c005f5f08170d743
--- /dev/null
+++ b/dust3r/cloud_opt/optimizer.py
@@ -0,0 +1,249 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Main class for the implementation of the global alignment
+# --------------------------------------------------------
+import numpy as np
+import torch
+import torch.nn as nn
+
+from dust3r.cloud_opt.base_opt import BasePCOptimizer
+from dust3r.utils.geometry import xy_grid, geotrf
+from dust3r.utils.device import to_cpu, to_numpy
+
+
+class PointCloudOptimizer(BasePCOptimizer):
+    """ Optimize a global scene, given a list of pairwise observations.
+    Graph node: images
+    Graph edges: observations = (pred1, pred2)
+    """
+
+    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.has_im_poses = True  # by definition of this class
+        self.focal_break = focal_break
+
+        # adding thing to optimize
+        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
+        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
+        self.im_focals = nn.ParameterList(torch.FloatTensor(
+            [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes)  # camera intrinsics
+        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
+        self.im_pp.requires_grad_(optimize_pp)
+
+        self.imshape = self.imshapes[0]
+        im_areas = [h*w for h, w in self.imshapes]
+
+        self.max_area = max(im_areas)
+
+        # adding thing to optimize
+        self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
+        self.im_poses = ParameterStack(self.im_poses, is_param=True)
+        self.im_focals = ParameterStack(self.im_focals, is_param=True)
+        self.im_pp = ParameterStack(self.im_pp, is_param=True)
+        self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
+        self.register_buffer('_grid', ParameterStack(
+            [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))
+
+        # pre-compute pixel weights
+        self.register_buffer('_weight_i', ParameterStack(
+            [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
+        self.register_buffer('_weight_j', ParameterStack(
+            [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))
+
+        # precompute
+        self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
+        self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
+        self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
+        self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
+        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
+        self.total_area_j = sum([im_areas[j] for i, j in self.edges])
+
+
+    def _check_all_imgs_are_selected(self, msk):
+        assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'
+
+    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
+        self._check_all_imgs_are_selected(pose_msk)
+
+        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
+            known_poses = [known_poses]
+        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
+            print(f' (setting pose #{idx} = {pose[:3,3]})')
+            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))
+
+        # normalize scale if there's at most one known pose
+        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
+        self.norm_pw_scale = (n_known_poses <= 1)
+
+        self.im_poses.requires_grad_(False)
+        self.norm_pw_scale = False
+
+    def preset_focal(self, known_focals, msk=None):
+        self._check_all_imgs_are_selected(msk)
+
+        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
+            print(f' (setting focal #{idx} = {focal})')
+            self._no_grad(self._set_focal(idx, focal))
+
+        self.im_focals.requires_grad_(False)
+
+    def preset_principal_point(self, known_pp, msk=None):
+        self._check_all_imgs_are_selected(msk)
+
+        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
+            print(f' (setting principal point #{idx} = {pp})')
+            self._no_grad(self._set_principal_point(idx, pp))
+
+        self.im_pp.requires_grad_(False)
+
+    def _get_msk_indices(self, msk):
+        if msk is None:
+            return range(self.n_imgs)
+        elif isinstance(msk, int):
+            return [msk]
+        elif isinstance(msk, (tuple, list)):
+            return self._get_msk_indices(np.array(msk))
+        elif msk.dtype in (bool, torch.bool, np.bool_):
+            assert len(msk) == self.n_imgs
+            return np.cumsum([0] + msk.tolist())
+        elif np.issubdtype(msk.dtype, np.integer):
+            return msk
+        else:
+            raise ValueError(f'bad {msk=}')
+
+    def _no_grad(self, tensor):
+        assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'
+
+    def _set_focal(self, idx, focal, force=False):
+        param = self.im_focals[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = self.focal_break * np.log(focal)
+        return param
+
+    def get_focals(self):  # "Recovering intrinsics" in the paper: recover the focal lengths
+        log_focals = torch.stack(list(self.im_focals), dim=0)
+        return (log_focals / self.focal_break).exp()
+
+    def get_known_focal_mask(self):
+        return torch.tensor([not (p.requires_grad) for p in self.im_focals])
+
+    def _set_principal_point(self, idx, pp, force=False):
+        param = self.im_pp[idx]
+        H, W = self.imshapes[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
+        return param
+
+    def get_principal_points(self):
+        return self._pp + 10 * self.im_pp  # principal points: image centers plus a small learned offset
+
+    def get_intrinsics(self):
+        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
+        focals = self.get_focals().flatten()
+        K[:, 0, 0] = K[:, 1, 1] = focals
+        K[:, :2, 2] = self.get_principal_points()
+        K[:, 2, 2] = 1
+        return K
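+
+    # The intrinsics assembled above follow the usual pinhole layout:
+    #   K = [[f, 0, cx],
+    #        [0, f, cy],
+    #        [0, 0,  1]]
+    # with a single shared focal f per image and principal point (cx, cy).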
+
+    def get_im_poses(self):  # cam-to-world, i.e. the inverse of the extrinsics
+        cam2world = self._get_poses(self.im_poses)
+        return cam2world
+
+    def _set_depthmap(self, idx, depth, force=False):
+        depth = _ravel_hw(depth, self.max_area)
+
+        param = self.im_depthmaps[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = depth.log().nan_to_num(neginf=0)
+        return param
+
+    def get_depthmaps(self, raw=False):  # the depth maps D introduced just above Eq. (1) in the paper
+        res = self.im_depthmaps.exp()
+        if not raw:
+            res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
+        return res
+
+    def depth_to_pts3d(self):  # back-project the depth maps D to world-frame points (the formula above Eq. (1))
+        # get depths and projection params if not provided
+        focals = self.get_focals()  # focal lengths, cf. "Recovering intrinsics" in the paper
+        pp = self.get_principal_points()  # principal points (roughly the image centers)
+        im_poses = self.get_im_poses()  # cam-to-world poses
+        depth = self.get_depthmaps(raw=True)  # depth maps D (above Eq. (1))
+
+        # get pointmaps in camera frame; self._grid holds the pixel grid of every input image
+        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)  # lift pixel coordinates to camera-frame 3D points
+        # project to world frame
+        return geotrf(im_poses, rel_ptmaps)  # map from the camera frames to the world frame
+
+    def get_pts3d(self, raw=False):  # world-frame 3D points, computed from the depth maps D (formula above Eq. (1))
+        res = self.depth_to_pts3d()
+        if not raw:
+            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
+        return res
+    # forward() returns the loss value of Eq. (5)
+    def forward(self):  # "Global optimization" in the paper
+        pw_poses = self.get_pw_poses()  # P_e of Eq. (5): cam-to-world pairwise poses (requires_grad=True)
+        pw_adapt = self.get_adaptors().unsqueeze(1)  # scale factor sigma_e of Eq. (5) (requires_grad=False)
+        proj_pts3d = self.get_pts3d(raw=True)  # the world-frame points being optimized (requires_grad=True)
+
+        # rotate the pairwise predictions into the world frame according to pw_poses (Eq. (5))
+        aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)  # _stacked_pred_i/j are the dust3r-predicted point clouds (requires_grad=False)
+        aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)
+
+        # compute the loss: compare each aligned prediction with the estimated world-frame points proj_pts3d
+        li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
+        lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j
+
+        return li + lj
+
+
+def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
+    pp = pp.unsqueeze(1)
+    focal = focal.unsqueeze(1)
+    assert focal.shape == (len(depth), 1, 1)
+    assert pp.shape == (len(depth), 1, 2)
+    assert pixel_grid.shape == depth.shape + (2,)
+    depth = depth.unsqueeze(-1)
+    return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)  # formula above Eq. (1): lift pixel coordinates to camera-frame 3D points using the intrinsics and depth D
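+# The return above is the standard pinhole back-projection (the formula above Eq. (1)):
+# for a pixel (u, v) with depth z, focal f and principal point (cx, cy),
+#   X = (u - cx) * z / f,   Y = (v - cy) * z / f,   Z = z.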
+
+
+def ParameterStack(params, keys=None, is_param=None, fill=0):
+    if keys is not None:
+        params = [params[k] for k in keys]
+
+    if fill > 0:
+        params = [_ravel_hw(p, fill) for p in params]
+
+    requires_grad = params[0].requires_grad
+    assert all(p.requires_grad == requires_grad for p in params)
+
+    params = torch.stack(list(params)).float().detach()
+    if is_param or requires_grad:
+        params = nn.Parameter(params)
+        params.requires_grad_(requires_grad)
+    return params
+
+
+def _ravel_hw(tensor, fill=0):
+    # ravel H,W
+    tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
+
+    if len(tensor) < fill:
+        tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
+    return tensor
+
+
+def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
+    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
+    return minf*focal_base, maxf*focal_base
+
+
+def apply_mask(img, msk):
+    img = img.copy()
+    img[msk] = 0
+    return img
diff --git a/dust3r/cloud_opt/pair_viewer.py b/dust3r/cloud_opt/pair_viewer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a49e9a17df9ddc489b8fe3dddc027636c0c5973d
--- /dev/null
+++ b/dust3r/cloud_opt/pair_viewer.py
@@ -0,0 +1,125 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Dummy optimizer for visualizing pairs
+# --------------------------------------------------------
+import numpy as np
+import torch
+import torch.nn as nn
+import cv2
+
+from dust3r.cloud_opt.base_opt import BasePCOptimizer
+from dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates
+from dust3r.cloud_opt.commons import edge_str
+from dust3r.post_process import estimate_focal_knowing_depth
+
+
+class PairViewer (BasePCOptimizer):
+    """
+    This is a dummy optimizer.
+    Use it only when the goal is to visualize the results for a pair of images (with is_symmetrized).
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.is_symmetrized and self.n_edges == 2
+        self.has_im_poses = True
+
+        # compute all parameters directly from raw input
+        self.focals = []
+        self.pp = []
+        rel_poses = []
+        confs = []
+        for i in range(self.n_imgs):
+            conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean())
+            print(f'  - {conf=:.3} for edge {i}-{1-i}')
+            confs.append(conf)
+
+            H, W = self.imshapes[i]
+            pts3d = self.pred_i[edge_str(i, 1-i)]
+            pp = torch.tensor((W/2, H/2))
+            focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld'))
+            self.focals.append(focal)
+            self.pp.append(pp)
+
+            # estimate the pose of pts1 in image 2
+            pixels = np.mgrid[:W, :H].T.astype(np.float32)
+            pts3d = self.pred_j[edge_str(1-i, i)].numpy()
+            assert pts3d.shape[:2] == (H, W)
+            msk = self.get_masks()[i].numpy()
+            K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
+
+            try:
+                res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
+                                         iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
+                success, R, T, inliers = res
+                assert success
+
+                R = cv2.Rodrigues(R)[0]  # world to cam
+                pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]])  # cam to world
+            except:
+                pose = np.eye(4)
+            rel_poses.append(torch.from_numpy(pose.astype(np.float32)))
+
+        # let's use the pair with the most confidence
+        if confs[0] > confs[1]:
+            # ptcloud is expressed in camera1
+            self.im_poses = [torch.eye(4), rel_poses[1]]  # I, cam2-to-cam1
+            self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]
+        else:
+            # ptcloud is expressed in camera2
+            self.im_poses = [rel_poses[0], torch.eye(4)]  # I, cam1-to-cam2
+            self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]
+
+        self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)
+        self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
+        self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)
+        self.depth = nn.ParameterList(self.depth)
+        for p in self.parameters():
+            p.requires_grad = False
+
+    def _set_depthmap(self, idx, depth, force=False):
+        print('_set_depthmap is ignored in PairViewer')
+        return
+
+    def get_depthmaps(self, raw=False):
+        depth = [d.to(self.device) for d in self.depth]
+        return depth
+
+    def _set_focal(self, idx, focal, force=False):
+        self.focals[idx] = focal
+
+    def get_focals(self):
+        return self.focals
+
+    def get_known_focal_mask(self):
+        return torch.tensor([not (p.requires_grad) for p in self.focals])
+
+    def get_principal_points(self):
+        return self.pp
+
+    def get_intrinsics(self):
+        focals = self.get_focals()
+        pps = self.get_principal_points()
+        K = torch.zeros((len(focals), 3, 3), device=self.device)
+        for i in range(len(focals)):
+            K[i, 0, 0] = K[i, 1, 1] = focals[i]
+            K[i, :2, 2] = pps[i]
+            K[i, 2, 2] = 1
+        return K
+
+    def get_im_poses(self):
+        return self.im_poses
+
+    def depth_to_pts3d(self):
+        pts3d = []
+        for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):
+            pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
+                                                             intrinsics.cpu().numpy(),
+                                                             im_pose.cpu().numpy())
+            pts3d.append(torch.from_numpy(pts).to(device=self.device))
+        return pts3d
+
+    def forward(self):
+        return float('nan')
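+
+
+# Illustrative usage sketch (not part of the original file; assumes the repo's
+# global_aligner factory in dust3r.cloud_opt selects this class):
+#   scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PairViewer)
+#   focals, poses, pts3d = scene.get_focals(), scene.get_im_poses(), scene.get_pts3d()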
diff --git a/dust3r/datasets/__init__.py b/dust3r/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc5e79718e4a3eb2e31c60c8a390e61a19ec5432
--- /dev/null
+++ b/dust3r/datasets/__init__.py
@@ -0,0 +1,42 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+from .utils.transforms import *
+from .base.batched_sampler import BatchedRandomSampler  # noqa: F401
+from .co3d import Co3d  # noqa: F401
+
+
+def get_data_loader(dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True):
+    import torch
+    from croco.utils.misc import get_world_size, get_rank
+
+    # pytorch dataset
+    if isinstance(dataset, str):
+        dataset = eval(dataset)
+
+    world_size = get_world_size()
+    rank = get_rank()
+
+    try:
+        sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
+                                       rank=rank, drop_last=drop_last)
+    except (AttributeError, NotImplementedError):
+        # not avail for this dataset
+        if torch.distributed.is_initialized():
+            sampler = torch.utils.data.DistributedSampler(
+                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
+            )
+        elif shuffle:
+            sampler = torch.utils.data.RandomSampler(dataset)
+        else:
+            sampler = torch.utils.data.SequentialSampler(dataset)
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        sampler=sampler,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        pin_memory=pin_mem,
+        drop_last=drop_last,
+    )
+
+    return data_loader
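+
+
+# Illustrative usage sketch (assumed arguments, not part of the original file);
+# the dataset argument may be a string that is eval'ed above, e.g.:
+#   loader = get_data_loader("1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', resolution=224)",
+#                            batch_size=4, num_workers=4)
+#   for view1, view2 in loader:
+#       ...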
diff --git a/dust3r/datasets/base/__init__.py b/dust3r/datasets/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e
--- /dev/null
+++ b/dust3r/datasets/base/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
diff --git a/dust3r/datasets/base/base_stereo_view_dataset.py b/dust3r/datasets/base/base_stereo_view_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..17390ca29d4437fc41f3c946b235888af9e4c888
--- /dev/null
+++ b/dust3r/datasets/base/base_stereo_view_dataset.py
@@ -0,0 +1,220 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# base class for implementing datasets
+# --------------------------------------------------------
+import PIL
+import numpy as np
+import torch
+
+from dust3r.datasets.base.easy_dataset import EasyDataset
+from dust3r.datasets.utils.transforms import ImgNorm
+from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates
+import dust3r.datasets.utils.cropping as cropping
+
+
+class BaseStereoViewDataset (EasyDataset):
+    """ Define all basic options.
+
+    Usage:
+        class MyDataset (BaseStereoViewDataset):
+            def _get_views(self, idx, rng):
+                # overload here
+                views = []
+                views.append(dict(img=, ...))
+                return views
+    """
+
+    def __init__(self, *,  # only keyword arguments
+                 split=None,
+                 resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
+                 transform=ImgNorm,
+                 aug_crop=False,
+                 seed=None):
+        self.num_views = 2
+        self.split = split
+        self._set_resolutions(resolution)
+
+        if isinstance(transform, str):
+            transform = eval(transform)
+        self.transform = transform
+
+        self.aug_crop = aug_crop
+        self.seed = seed
+
+    def __len__(self):
+        return len(self.scenes)
+
+    def get_stats(self):
+        return f"{len(self)} pairs"
+
+    def __repr__(self):
+        resolutions_str = '['+';'.join(f'{w}x{h}' for w, h in self._resolutions)+']'
+        return f"""{type(self).__name__}({self.get_stats()},
+            {self.split=},
+            {self.seed=},
+            resolutions={resolutions_str},
+            {self.transform=})""".replace('self.', '').replace('\n', '').replace('   ', '')
+
+    def _get_views(self, idx, resolution, rng):
+        raise NotImplementedError()
+
+    def __getitem__(self, idx):
+        if isinstance(idx, tuple):
+            # the idx is specifying the aspect-ratio
+            idx, ar_idx = idx
+        else:
+            assert len(self._resolutions) == 1
+            ar_idx = 0
+
+        # set-up the rng
+        if self.seed:  # reseed for each __getitem__
+            self._rng = np.random.default_rng(seed=self.seed + idx)
+        elif not hasattr(self, '_rng'):
+            seed = torch.initial_seed()  # this is different for each dataloader process
+            self._rng = np.random.default_rng(seed=seed)
+
+        # over-loaded code
+        resolution = self._resolutions[ar_idx]  # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
+        views = self._get_views(idx, resolution, self._rng)
+        assert len(views) == self.num_views
+
+        # check data-types
+        for v, view in enumerate(views):
+            assert 'pts3d' not in view, f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
+            view['idx'] = (idx, ar_idx, v)
+
+            # encode the image
+            width, height = view['img'].size
+            view['true_shape'] = np.int32((height, width))
+            view['img'] = self.transform(view['img'])
+
+            assert 'camera_intrinsics' in view
+            if 'camera_pose' not in view:
+                view['camera_pose'] = np.full((4, 4), np.nan, dtype=np.float32)
+            else:
+                assert np.isfinite(view['camera_pose']).all(), f'NaN in camera pose for view {view_name(view)}'
+            assert 'pts3d' not in view
+            assert 'valid_mask' not in view
+            assert np.isfinite(view['depthmap']).all(), f'NaN in depthmap for view {view_name(view)}'
+            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
+
+            view['pts3d'] = pts3d
+            view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
+
+            # check all datatypes
+            for key, val in view.items():
+                res, err_msg = is_good_type(key, val)
+                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
+            K = view['camera_intrinsics']
+
+        # last thing done!
+        for view in views:
+            # transpose portrait views to landscape so that all views share the same orientation
+            transpose_to_landscape(view)
+            # this allows checking whether the RNG is in the same state each time
+            view['rng'] = int.from_bytes(self._rng.bytes(4), 'big')
+        return views
+
+    def _set_resolutions(self, resolutions):
+        assert resolutions is not None, 'undefined resolution'
+
+        if not isinstance(resolutions, list):
+            resolutions = [resolutions]
+
+        self._resolutions = []
+        for resolution in resolutions:
+            if isinstance(resolution, int):
+                width = height = resolution
+            else:
+                width, height = resolution
+            assert isinstance(width, int), f'Bad type for {width=} {type(width)=}, should be int'
+            assert isinstance(height, int), f'Bad type for {height=} {type(height)=}, should be int'
+            assert width >= height
+            self._resolutions.append((width, height))
+
+    def _crop_resize_if_necessary(self, image, depthmap, intrinsics, resolution, rng=None, info=None):
+        """ This function:
+            - first downsizes the image with LANCZOS inteprolation,
+              which is better than bilinear interpolation in
+        """
+        if not isinstance(image, PIL.Image.Image):
+            image = PIL.Image.fromarray(image)
+
+        # downscale with lanczos interpolation so that image.size == resolution
+        # cropping centered on the principal point
+        W, H = image.size
+        cx, cy = intrinsics[:2, 2].round().astype(int)
+        min_margin_x = min(cx, W-cx)
+        min_margin_y = min(cy, H-cy)
+        assert min_margin_x > W/5, f'Bad principal point in view={info}'
+        assert min_margin_y > H/5, f'Bad principal point in view={info}'
+        # the new window will be a rectangle of size (2*min_margin_x, 2*min_margin_y) centered on (cx,cy)
+        l, t = cx - min_margin_x, cy - min_margin_y
+        r, b = cx + min_margin_x, cy + min_margin_y
+        crop_bbox = (l, t, r, b)
+        image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
+
+        # transpose the resolution if necessary
+        W, H = image.size  # new size
+        assert resolution[0] >= resolution[1]
+        if H > 1.1*W:
+            # image is portrait mode
+            resolution = resolution[::-1]
+        elif 0.9 < H/W < 1.1 and resolution[0] != resolution[1]:
+            # image is roughly square, so we choose between portrait and landscape randomly
+            if rng.integers(2):
+                resolution = resolution[::-1]
+
+        # high-quality Lanczos down-scaling
+        target_resolution = np.array(resolution)
+        if self.aug_crop > 1:
+            target_resolution += rng.integers(0, self.aug_crop)
+        image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)
+
+        # actual cropping (if necessary) with bilinear interpolation
+        intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)
+        crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
+        image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)
+
+        return image, depthmap, intrinsics2
+
+
+def is_good_type(key, v):
+    """ returns (is_good, err_msg) 
+    """
+    if isinstance(v, (str, int, tuple)):
+        return True, None
+    if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
+        return False, f"bad {v.dtype=}"
+    return True, None
+
+
+def view_name(view, batch_index=None):
+    def sel(x): return x[batch_index] if batch_index not in (None, slice(None)) else x
+    db = sel(view['dataset'])
+    label = sel(view['label'])
+    instance = sel(view['instance'])
+    return f"{db}/{label}/{instance}"
+
+
+def transpose_to_landscape(view):
+    height, width = view['true_shape']
+
+    if width < height:
+        # rectify portrait to landscape
+        assert view['img'].shape == (3, height, width)
+        view['img'] = view['img'].swapaxes(1, 2)
+
+        assert view['valid_mask'].shape == (height, width)
+        view['valid_mask'] = view['valid_mask'].swapaxes(0, 1)
+
+        assert view['depthmap'].shape == (height, width)
+        view['depthmap'] = view['depthmap'].swapaxes(0, 1)
+
+        assert view['pts3d'].shape == (height, width, 3)
+        view['pts3d'] = view['pts3d'].swapaxes(0, 1)
+
+        # transpose x and y pixels
+        view['camera_intrinsics'] = view['camera_intrinsics'][[1, 0, 2]]
diff --git a/dust3r/datasets/base/batched_sampler.py b/dust3r/datasets/base/batched_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..85f58a65d41bb8101159e032d5b0aac26a7cf1a1
--- /dev/null
+++ b/dust3r/datasets/base/batched_sampler.py
@@ -0,0 +1,74 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Random sampling under a constraint
+# --------------------------------------------------------
+import numpy as np
+import torch
+
+
+class BatchedRandomSampler:
+    """ Random sampling under a constraint: each sample in the batch has the same feature, 
+    which is chosen randomly from a known pool of 'features' for each batch.
+
+    For instance, the 'feature' could be the image aspect-ratio.
+
+    The index returned is a tuple (sample_idx, feat_idx).
+    This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
+    """
+
+    def __init__(self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True):
+        self.batch_size = batch_size
+        self.pool_size = pool_size
+
+        self.len_dataset = N = len(dataset)
+        self.total_size = round_by(N, batch_size*world_size) if drop_last else N
+        assert world_size == 1 or drop_last, 'must drop the last batch in distributed mode'
+
+        # distributed sampler
+        self.world_size = world_size
+        self.rank = rank
+        self.epoch = None
+
+    def __len__(self):
+        return self.total_size // self.world_size
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+    def __iter__(self):
+        # prepare RNG
+        if self.epoch is None:
+            assert self.world_size == 1 and self.rank == 0, 'use set_epoch() if distributed mode is used'
+            seed = int(torch.empty((), dtype=torch.int64).random_().item())
+        else:
+            seed = self.epoch + 777
+        rng = np.random.default_rng(seed=seed)
+
+        # random indices (will restart from 0 if not drop_last)
+        sample_idxs = np.arange(self.total_size)
+        rng.shuffle(sample_idxs)
+
+        # random feat_idxs (same across each batch)
+        n_batches = (self.total_size+self.batch_size-1) // self.batch_size
+        feat_idxs = rng.integers(self.pool_size, size=n_batches)
+        feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
+        feat_idxs = feat_idxs.ravel()[:self.total_size]
+
+        # put them together
+        idxs = np.c_[sample_idxs, feat_idxs]  # shape = (total_size, 2)
+
+        # Distributed sampler: we select a subset of batches
+        # make sure the slice for each node is aligned with batch_size
+        size_per_proc = self.batch_size * ((self.total_size + self.world_size *
+                                           self.batch_size-1) // (self.world_size * self.batch_size))
+        idxs = idxs[self.rank*size_per_proc: (self.rank+1)*size_per_proc]
+
+        yield from (tuple(idx) for idx in idxs)
+
+
+def round_by(total, multiple, up=False):
+    if up:
+        total = total + multiple-1
+    return (total//multiple) * multiple
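+
+
+# Illustrative sketch of what BatchedRandomSampler yields (values assumed): with
+# batch_size=2 and pool_size=3, a possible index stream is
+#   (17, 1), (4, 1), (42, 0), (11, 0), ...
+# i.e. every consecutive group of batch_size samples shares the same feat_idx,
+# so all images in a batch end up with the same aspect-ratio/resolution.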
diff --git a/dust3r/datasets/base/easy_dataset.py b/dust3r/datasets/base/easy_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4939a88f02715a1f80be943ddb6d808e1be84db7
--- /dev/null
+++ b/dust3r/datasets/base/easy_dataset.py
@@ -0,0 +1,157 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# A dataset base class that you can easily resize and combine.
+# --------------------------------------------------------
+import numpy as np
+from dust3r.datasets.base.batched_sampler import BatchedRandomSampler
+
+
+class EasyDataset:
+    """ a dataset that you can easily resize and combine.
+    Examples:
+    ---------
+        2 * dataset ==> duplicate each element 2x
+
+        10 @ dataset ==> set the size to 10 (random sampling, duplicates if necessary)
+
+        dataset1 + dataset2 ==> concatenate datasets
+    """
+
+    def __add__(self, other):
+        return CatDataset([self, other])
+
+    def __rmul__(self, factor):
+        return MulDataset(factor, self)
+
+    def __rmatmul__(self, factor):
+        return ResizedDataset(factor, self)
+
+    def set_epoch(self, epoch):
+        pass  # nothing to do by default
+
+    def make_sampler(self, batch_size, shuffle=True, world_size=1, rank=0, drop_last=True):
+        if not (shuffle):
+            raise NotImplementedError()  # cannot deal yet
+        num_of_aspect_ratios = len(self._resolutions)
+        return BatchedRandomSampler(self, batch_size, num_of_aspect_ratios, world_size=world_size, rank=rank, drop_last=drop_last)
+
+
+class MulDataset (EasyDataset):
+    """ Artifically augmenting the size of a dataset.
+    """
+    multiplicator: int
+
+    def __init__(self, multiplicator, dataset):
+        assert isinstance(multiplicator, int) and multiplicator > 0
+        self.multiplicator = multiplicator
+        self.dataset = dataset
+
+    def __len__(self):
+        return self.multiplicator * len(self.dataset)
+
+    def __repr__(self):
+        return f'{self.multiplicator}*{repr(self.dataset)}'
+
+    def __getitem__(self, idx):
+        if isinstance(idx, tuple):
+            idx, other = idx
+            return self.dataset[idx // self.multiplicator, other]
+        else:
+            return self.dataset[idx // self.multiplicator]
+
+    @property
+    def _resolutions(self):
+        return self.dataset._resolutions
+
+
+class ResizedDataset (EasyDataset):
+    """ Artifically changing the size of a dataset.
+    """
+    new_size: int
+
+    def __init__(self, new_size, dataset):
+        assert isinstance(new_size, int) and new_size > 0
+        self.new_size = new_size
+        self.dataset = dataset
+
+    def __len__(self):
+        return self.new_size
+
+    def __repr__(self):
+        size_str = str(self.new_size)
+        for i in range((len(size_str)-1) // 3):
+            sep = -4*i-3
+            size_str = size_str[:sep] + '_' + size_str[sep:]
+        return f'{size_str} @ {repr(self.dataset)}'
+
+    def set_epoch(self, epoch):
+        # this random shuffle only depends on the epoch
+        rng = np.random.default_rng(seed=epoch+777)
+
+        # shuffle all indices
+        perm = rng.permutation(len(self.dataset))
+
+        # rotary extension until target size is met
+        shuffled_idxs = np.concatenate([perm] * (1 + (len(self)-1) // len(self.dataset)))
+        self._idxs_mapping = shuffled_idxs[:self.new_size]
+
+        assert len(self._idxs_mapping) == self.new_size
+
+    def __getitem__(self, idx):
+        assert hasattr(self, '_idxs_mapping'), 'You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()'
+        if isinstance(idx, tuple):
+            idx, other = idx
+            return self.dataset[self._idxs_mapping[idx], other]
+        else:
+            return self.dataset[self._idxs_mapping[idx]]
+
+    @property
+    def _resolutions(self):
+        return self.dataset._resolutions
+
+
+class CatDataset (EasyDataset):
+    """ Concatenation of several datasets 
+    """
+
+    def __init__(self, datasets):
+        for dataset in datasets:
+            assert isinstance(dataset, EasyDataset)
+        self.datasets = datasets
+        self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])
+
+    def __len__(self):
+        return self._cum_sizes[-1]
+
+    def __repr__(self):
+        # remove uselessly long transform
+        return ' + '.join(repr(dataset).replace(',transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))', '') for dataset in self.datasets)
+
+    def set_epoch(self, epoch):
+        for dataset in self.datasets:
+            dataset.set_epoch(epoch)
+
+    def __getitem__(self, idx):
+        other = None
+        if isinstance(idx, tuple):
+            idx, other = idx
+
+        if not (0 <= idx < len(self)):
+            raise IndexError()
+
+        db_idx = np.searchsorted(self._cum_sizes, idx, 'right')
+        dataset = self.datasets[db_idx]
+        new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)
+
+        if other is not None:
+            new_idx = (new_idx, other)
+        return dataset[new_idx]
+
+    @property
+    def _resolutions(self):
+        resolutions = self.datasets[0]._resolutions
+        for dataset in self.datasets[1:]:
+            assert tuple(dataset._resolutions) == tuple(resolutions)
+        return resolutions
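+
+
+# Illustrative sketch of the operators defined above (the Co3d arguments are the
+# ones used in the demo block of dust3r/datasets/co3d.py):
+#   base = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224)
+#   train_set = 100_000 @ (2 * base)  # duplicate 2x, then resample to 100k items
+#   train_set.set_epoch(0)            # ResizedDataset requires set_epoch() before indexing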
diff --git a/dust3r/datasets/co3d.py b/dust3r/datasets/co3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fc94f9420d86372e643c00e7cddf85b3d1982c6
--- /dev/null
+++ b/dust3r/datasets/co3d.py
@@ -0,0 +1,146 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Dataloader for preprocessed Co3d_v2
+# dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International
+# See datasets_preprocess/preprocess_co3d.py
+# --------------------------------------------------------
+import os.path as osp
+import json
+import itertools
+from collections import deque
+
+import cv2
+import numpy as np
+
+from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
+from dust3r.utils.image import imread_cv2
+
+
+class Co3d(BaseStereoViewDataset):
+    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
+        self.ROOT = ROOT
+        super().__init__(*args, **kwargs)
+        assert mask_bg in (True, False, 'rand')
+        self.mask_bg = mask_bg
+
+        # load all scenes
+        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
+            self.scenes = json.load(f)
+            self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
+            self.scenes = {(k, k2): v2 for k, v in self.scenes.items()
+                           for k2, v2 in v.items()}
+        self.scene_list = list(self.scenes.keys())
+
+        # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees)
+        # we prepare all combinations such that i-j = +/- [5, 10, .., 90] degrees
+        self.combinations = [(i, j)
+                             for i, j in itertools.combinations(range(100), 2)
+                             if 0 < abs(i-j) <= 30 and abs(i-j) % 5 == 0]
+
+        self.invalidate = {scene: {} for scene in self.scene_list}
+
+    def __len__(self):
+        return len(self.scene_list) * len(self.combinations)
+
+    def _get_views(self, idx, resolution, rng):
+        # choose a scene
+        obj, instance = self.scene_list[idx // len(self.combinations)]
+        image_pool = self.scenes[obj, instance]
+        im1_idx, im2_idx = self.combinations[idx % len(self.combinations)]
+
+        # add a bit of randomness
+        last = len(image_pool)-1
+
+        if resolution not in self.invalidate[obj, instance]:  # flag invalid images
+            self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]
+
+        # decide now if we mask the bg
+        mask_bg = (self.mask_bg == True) or (self.mask_bg == 'rand' and rng.choice(2))
+
+        views = []
+        imgs_idxs = [max(0, min(im_idx + rng.integers(-4, 5), last)) for im_idx in [im2_idx, im1_idx]]
+        imgs_idxs = deque(imgs_idxs)
+        while len(imgs_idxs) > 0:  # some images (few) have zero depth
+            im_idx = imgs_idxs.pop()
+
+            if self.invalidate[obj, instance][resolution][im_idx]:
+                # search for a valid image
+                random_direction = 2 * rng.choice(2) - 1
+                for offset in range(1, len(image_pool)):
+                    tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
+                    if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
+                        im_idx = tentative_im_idx
+                        break
+
+            view_idx = image_pool[im_idx]
+
+            impath = osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')
+
+            # load camera params
+            input_metadata = np.load(impath.replace('jpg', 'npz'))
+            camera_pose = input_metadata['camera_pose'].astype(np.float32)
+            intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)
+
+            # load image and depth
+            rgb_image = imread_cv2(impath)
+            depthmap = imread_cv2(impath.replace('images', 'depths') + '.geometric.png', cv2.IMREAD_UNCHANGED)
+            depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])
+
+            if mask_bg:
+                # load object mask
+                maskpath = osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')
+                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
+                maskmap = (maskmap / 255.0) > 0.1
+
+                # update the depthmap with mask
+                depthmap *= maskmap
+
+            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
+                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)
+
+            num_valid = (depthmap > 0.0).sum()
+            if num_valid == 0:
+                # problem, invalidate image and retry
+                self.invalidate[obj, instance][resolution][im_idx] = True
+                imgs_idxs.append(im_idx)
+                continue
+
+            views.append(dict(
+                img=rgb_image,
+                depthmap=depthmap,
+                camera_pose=camera_pose,
+                camera_intrinsics=intrinsics,
+                dataset='Co3d_v2',
+                label=osp.join(obj, instance),
+                instance=osp.split(impath)[1],
+            ))
+        return views
+
+
+if __name__ == "__main__":
+    from dust3r.datasets.base.base_stereo_view_dataset import view_name
+    from dust3r.viz import SceneViz, auto_cam_size
+    from dust3r.utils.image import rgb
+
+    dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16)
+
+    for idx in np.random.permutation(len(dataset)):
+        views = dataset[idx]
+        assert len(views) == 2
+        print(view_name(views[0]), view_name(views[1]))
+        viz = SceneViz()
+        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
+        cam_size = max(auto_cam_size(poses), 0.001)
+        for view_idx in [0, 1]:
+            pts3d = views[view_idx]['pts3d']
+            valid_mask = views[view_idx]['valid_mask']
+            colors = rgb(views[view_idx]['img'])
+            viz.add_pointcloud(pts3d, colors, valid_mask)
+            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
+                           focal=views[view_idx]['camera_intrinsics'][0, 0],
+                           color=(idx*255, (1 - idx)*255, 0),
+                           image=colors,
+                           cam_size=cam_size)
+        viz.show()
diff --git a/dust3r/datasets/utils/__init__.py b/dust3r/datasets/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e
--- /dev/null
+++ b/dust3r/datasets/utils/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
diff --git a/dust3r/datasets/utils/cropping.py b/dust3r/datasets/utils/cropping.py
new file mode 100644
index 0000000000000000000000000000000000000000..02b1915676f3deea24f57032f7588ff34cbfaeb9
--- /dev/null
+++ b/dust3r/datasets/utils/cropping.py
@@ -0,0 +1,119 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# cropping utilities
+# --------------------------------------------------------
+import PIL.Image
+import os
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+import cv2  # noqa
+import numpy as np  # noqa
+from dust3r.utils.geometry import colmap_to_opencv_intrinsics, opencv_to_colmap_intrinsics  # noqa
+try:
+    lanczos = PIL.Image.Resampling.LANCZOS
+except AttributeError:
+    lanczos = PIL.Image.LANCZOS
+
+
+class ImageList:
+    """ Convenience class to aply the same operation to a whole set of images.
+    """
+
+    def __init__(self, images):
+        if not isinstance(images, (tuple, list, set)):
+            images = [images]
+        self.images = []
+        for image in images:
+            if not isinstance(image, PIL.Image.Image):
+                image = PIL.Image.fromarray(image)
+            self.images.append(image)
+
+    def __len__(self):
+        return len(self.images)
+
+    def to_pil(self):
+        return tuple(self.images) if len(self.images) > 1 else self.images[0]
+
+    @property
+    def size(self):
+        sizes = [im.size for im in self.images]
+        assert all(sizes[0] == s for s in sizes)
+        return sizes[0]
+
+    def resize(self, *args, **kwargs):
+        return ImageList(self._dispatch('resize', *args, **kwargs))
+
+    def crop(self, *args, **kwargs):
+        return ImageList(self._dispatch('crop', *args, **kwargs))
+
+    def _dispatch(self, func, *args, **kwargs):
+        return [getattr(im, func)(*args, **kwargs) for im in self.images]
+
+
+def rescale_image_depthmap(image, depthmap, camera_intrinsics, output_resolution):
+    """ Jointly rescale a (image, depthmap) 
+        so that (out_width, out_height) >= output_res
+    """
+    image = ImageList(image)
+    input_resolution = np.array(image.size)  # (W,H)
+    output_resolution = np.array(output_resolution)
+    if depthmap is not None:
+        # can also use this with masks instead of depthmaps
+        assert tuple(depthmap.shape[:2]) == image.size[::-1]
+    assert output_resolution.shape == (2,)
+    # define output resolution
+    scale_final = max(output_resolution / image.size) + 1e-8
+    output_resolution = np.floor(input_resolution * scale_final).astype(int)
+
+    # first rescale the image so that it contains the crop
+    image = image.resize(output_resolution, resample=lanczos)
+    if depthmap is not None:
+        depthmap = cv2.resize(depthmap, output_resolution, fx=scale_final,
+                              fy=scale_final, interpolation=cv2.INTER_NEAREST)
+
+    # no offset here; simple rescaling
+    camera_intrinsics = camera_matrix_of_crop(
+        camera_intrinsics, input_resolution, output_resolution, scaling=scale_final)
+
+    return image.to_pil(), depthmap, camera_intrinsics
+
+
+def camera_matrix_of_crop(input_camera_matrix, input_resolution, output_resolution, scaling=1, offset_factor=0.5, offset=None):
+    # Margins to offset the origin
+    margins = np.asarray(input_resolution) * scaling - output_resolution
+    assert np.all(margins >= 0.0)
+    if offset is None:
+        offset = offset_factor * margins
+
+    # Generate new camera parameters
+    output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
+    output_camera_matrix_colmap[:2, :] *= scaling
+    output_camera_matrix_colmap[:2, 2] -= offset
+    output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
+
+    return output_camera_matrix
+
+
+def crop_image_depthmap(image, depthmap, camera_intrinsics, crop_bbox):
+    """
+    Return a crop of the input view.
+    """
+    image = ImageList(image)
+    l, t, r, b = crop_bbox
+
+    image = image.crop((l, t, r, b))
+    depthmap = depthmap[t:b, l:r]
+
+    camera_intrinsics = camera_intrinsics.copy()
+    camera_intrinsics[0, 2] -= l
+    camera_intrinsics[1, 2] -= t
+
+    return image.to_pil(), depthmap, camera_intrinsics
+
+
+def bbox_from_intrinsics_in_out(input_camera_matrix, output_camera_matrix, output_resolution):
+    out_width, out_height = output_resolution
+    l, t = np.int32(np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2]))
+    crop_bbox = (l, t, l+out_width, t+out_height)
+    return crop_bbox
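+
+
+# Illustrative sketch of the intended rescale-then-crop pipeline (mirrors
+# BaseStereoViewDataset._crop_resize_if_necessary; variable names are assumed):
+#   image, depthmap, K = rescale_image_depthmap(image, depthmap, K, (512, 384))
+#   K2 = camera_matrix_of_crop(K, image.size, (512, 384), offset_factor=0.5)
+#   bbox = bbox_from_intrinsics_in_out(K, K2, (512, 384))
+#   image, depthmap, K2 = crop_image_depthmap(image, depthmap, K, bbox)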
diff --git a/dust3r/datasets/utils/transforms.py b/dust3r/datasets/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb34f2f01d3f8f829ba71a7e03e181bf18f72c25
--- /dev/null
+++ b/dust3r/datasets/utils/transforms.py
@@ -0,0 +1,11 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# DUST3R default transforms
+# --------------------------------------------------------
+import torchvision.transforms as tvf
+from dust3r.utils.image import ImgNorm
+
+# define the standard image transforms
+ColorJitter = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
diff --git a/dust3r/heads/__init__.py b/dust3r/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..53d0aa5610cae95f34f96bdb3ff9e835a2d6208e
--- /dev/null
+++ b/dust3r/heads/__init__.py
@@ -0,0 +1,19 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# head factory
+# --------------------------------------------------------
+from .linear_head import LinearPts3d
+from .dpt_head import create_dpt_head
+
+
+def head_factory(head_type, output_mode, net, has_conf=False):
+    """" build a prediction head for the decoder 
+    """
+    if head_type == 'linear' and output_mode == 'pts3d':
+        return LinearPts3d(net, has_conf)
+    elif head_type == 'dpt' and output_mode == 'pts3d':
+        return create_dpt_head(net, has_conf=has_conf)
+    else:
+        raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
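+
+
+# Illustrative sketch (assumed call site): inside the model constructor the factory
+# is typically invoked as
+#   head = head_factory('dpt', 'pts3d', net, has_conf=True)
+# which returns a PixelwiseTaskWithDPT predicting 3 pts3d channels + 1 confidence channel.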
diff --git a/dust3r/heads/__pycache__/__init__.cpython-310.pyc b/dust3r/heads/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ef2d76fef4e34c8f9284b1079f25e880434d2b3
Binary files /dev/null and b/dust3r/heads/__pycache__/__init__.cpython-310.pyc differ
diff --git a/dust3r/heads/__pycache__/__init__.cpython-38.pyc b/dust3r/heads/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b657e9c74d6431a5693f6c639e2553dbd3c3568
Binary files /dev/null and b/dust3r/heads/__pycache__/__init__.cpython-38.pyc differ
diff --git a/dust3r/heads/__pycache__/__init__.cpython-39.pyc b/dust3r/heads/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ddfb3c493c8261a3b6c08938f59158040fecf40
Binary files /dev/null and b/dust3r/heads/__pycache__/__init__.cpython-39.pyc differ
diff --git a/dust3r/heads/__pycache__/dpt_head.cpython-310.pyc b/dust3r/heads/__pycache__/dpt_head.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a78e7a38180bebd8c889ffaa3860abc261752f3c
Binary files /dev/null and b/dust3r/heads/__pycache__/dpt_head.cpython-310.pyc differ
diff --git a/dust3r/heads/__pycache__/dpt_head.cpython-38.pyc b/dust3r/heads/__pycache__/dpt_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..934f8df70d76b338b1d56bd557a494d94f2b3f1a
Binary files /dev/null and b/dust3r/heads/__pycache__/dpt_head.cpython-38.pyc differ
diff --git a/dust3r/heads/__pycache__/linear_head.cpython-310.pyc b/dust3r/heads/__pycache__/linear_head.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e5fa3e11aaf16b0453c2c79babc49ff04f4630ae
Binary files /dev/null and b/dust3r/heads/__pycache__/linear_head.cpython-310.pyc differ
diff --git a/dust3r/heads/__pycache__/linear_head.cpython-38.pyc b/dust3r/heads/__pycache__/linear_head.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..054ea45c2d4b958cbb048c4794ab4b63a61d5111
Binary files /dev/null and b/dust3r/heads/__pycache__/linear_head.cpython-38.pyc differ
diff --git a/dust3r/heads/__pycache__/linear_head.cpython-39.pyc b/dust3r/heads/__pycache__/linear_head.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..171a27cf2e13bd957b71eaf63f28fbae0b01cb5a
Binary files /dev/null and b/dust3r/heads/__pycache__/linear_head.cpython-39.pyc differ
diff --git a/dust3r/heads/__pycache__/postprocess.cpython-310.pyc b/dust3r/heads/__pycache__/postprocess.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a04b32eb15aee535f6da5c955b110af76253b30c
Binary files /dev/null and b/dust3r/heads/__pycache__/postprocess.cpython-310.pyc differ
diff --git a/dust3r/heads/__pycache__/postprocess.cpython-38.pyc b/dust3r/heads/__pycache__/postprocess.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..768959709e9caca008efec7bac383037ebce737c
Binary files /dev/null and b/dust3r/heads/__pycache__/postprocess.cpython-38.pyc differ
diff --git a/dust3r/heads/dpt_head.py b/dust3r/heads/dpt_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..095890b95fe7b81fe90d04b77a9e20429381684b
--- /dev/null
+++ b/dust3r/heads/dpt_head.py
@@ -0,0 +1,115 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# dpt head implementation for DUST3R
+# Downstream heads assume inputs of size B x N x C (where N is the number of tokens);
+# if a head takes as input the output of every layer, its attribute return_all_layers should be set to True.
+# The forward function also takes as input a dictionary img_info with keys "height" and "width".
+# For PixelwiseTask, the output will be of dimension B x num_channels x H x W.
+# --------------------------------------------------------
+from einops import rearrange
+from typing import List
+import torch
+import torch.nn as nn
+from dust3r.heads.postprocess import postprocess
+import dust3r.utils.path_to_croco  # noqa: F401
+from croco.models.dpt_block import DPTOutputAdapter  # noqa
+
+
+class DPTOutputAdapter_fix(DPTOutputAdapter):
+    """
+    Adapt croco's DPTOutputAdapter implementation for dust3rWithSam2:
+    remove duplicated weights, and fix the forward pass for dust3rWithSam2
+    """
+
+    def init(self, dim_tokens_enc=768):
+        super().init(dim_tokens_enc)
+        # these are duplicated weights
+        del self.act_1_postprocess
+        del self.act_2_postprocess
+        del self.act_3_postprocess
+        del self.act_4_postprocess
+
+    def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
+        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
+        # H, W = input_info['image_size']
+        image_size = self.image_size if image_size is None else image_size
+        H, W = image_size
+        # Number of patches in height and width
+        N_H = H // (self.stride_level * self.P_H)
+        N_W = W // (self.stride_level * self.P_W)
+        # the decoder produces 13 intermediate outputs in total; the ones at indices [0, 6, 9, 12] are collected into `layers`
+        # Hook decoder onto 4 layers from specified ViT layers
+        layers = [encoder_tokens[hook] for hook in self.hooks] # [0,6,9,12]
+
+        # Extract only task-relevant tokens and ignore global tokens.
+        layers = [self.adapt_tokens(l) for l in layers]
+
+        # Reshape tokens to spatial representation
+        layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
+        # apply the matching convolution to each hooked layer, mapping the outputs of the decoder blocks to different spatial sizes before feeding them to the RefineNet
+        layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
+        # Project layers to the chosen feature dim: one more convolution per layer to unify the channel count to 256
+        layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
+
+        # Fuse layers using refinement stages
+        path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
+        path_3 = self.scratch.refinenet3(path_4, layers[2])
+        path_2 = self.scratch.refinenet2(path_3, layers[1])
+        path_1 = self.scratch.refinenet1(path_2, layers[0])
+
+        # Output head
+        out = self.head(path_1)
+
+        return out
+
+
+class PixelwiseTaskWithDPT(nn.Module):
+    """ DPT module for dust3rWithSam2, can return 3D points + confidence for all pixels"""
+
+    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
+                 output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
+        super(PixelwiseTaskWithDPT, self).__init__()
+        self.return_all_layers = True  # backbone needs to return all layers
+        self.postprocess = postprocess
+        self.depth_mode = depth_mode
+        self.conf_mode = conf_mode
+
+        assert n_cls_token == 0, "Not implemented"
+        dpt_args = dict(output_width_ratio=output_width_ratio,
+                        num_channels=num_channels,
+                        **kwargs)
+        if hooks_idx is not None:
+            dpt_args.update(hooks=hooks_idx)
+        self.dpt = DPTOutputAdapter_fix(**dpt_args)
+        dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
+        self.dpt.init(**dpt_init_args)
+
+    def forward(self, x, img_info): # Head
+        out = self.dpt(x, image_size=(img_info[0], img_info[1]))
+        if self.postprocess:
+            out = self.postprocess(out, self.depth_mode, self.conf_mode)
+        return out
+
+
+def create_dpt_head(net, has_conf=False):
+    """
+    return PixelwiseTaskWithDPT for given net params
+    """
+    assert net.dec_depth > 9
+    l2 = net.dec_depth
+    feature_dim = 256
+    last_dim = feature_dim//2
+    out_nchan = 3
+    ed = net.enc_embed_dim
+    dd = net.dec_embed_dim
+    return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
+                                feature_dim=feature_dim,
+                                last_dim=last_dim,
+                                hooks_idx=[0, l2*2//4, l2*3//4, l2],
+                                dim_tokens=[ed, dd, dd, dd],
+                                postprocess=postprocess,
+                                depth_mode=net.depth_mode,
+                                conf_mode=net.conf_mode,
+                                head_type='regression')
diff --git a/dust3r/heads/linear_head.py b/dust3r/heads/linear_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4ba75851e36eb5cbf9e3ca44f58282fbd69509d
--- /dev/null
+++ b/dust3r/heads/linear_head.py
@@ -0,0 +1,41 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# linear head implementation for DUST3R
+# --------------------------------------------------------
+import torch.nn as nn
+import torch.nn.functional as F
+from dust3r.heads.postprocess import postprocess
+
+
+class LinearPts3d (nn.Module):
+    """ 
+    Linear head for dust3rWithSam2
+    Each token outputs: - 16x16 3D points (+ confidence)
+    """
+
+    def __init__(self, net, has_conf=False):
+        super().__init__()
+        self.patch_size = net.patch_embed.patch_size[0]
+        self.depth_mode = net.depth_mode
+        self.conf_mode = net.conf_mode
+        self.has_conf = has_conf
+
+        self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)
+
+    def setup(self, croconet):
+        pass
+
+    def forward(self, decout, img_shape):
+        H, W = img_shape
+        tokens = decout[-1]
+        B, S, D = tokens.shape
+
+        # extract 3D points
+        feat = self.proj(tokens)  # B,S,D
+        feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
+        feat = F.pixel_shuffle(feat, self.patch_size)  # B,3,H,W
+
+        # permute + norm depth
+        return postprocess(feat, self.depth_mode, self.conf_mode)
diff --git a/dust3r/heads/postprocess.py b/dust3r/heads/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc84628dda7558b584e928c4967b375e085dde49
--- /dev/null
+++ b/dust3r/heads/postprocess.py
@@ -0,0 +1,58 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# post process function for all heads: extract 3D points/confidence from output
+# --------------------------------------------------------
+import torch
+
+
+def postprocess(out, depth_mode, conf_mode):
+    """
+    extract 3D points/confidence from prediction head output 
+    """ # out的通道数为4,即分别表示三维点云的xyz坐标值和conf置信度
+    fmap = out.permute(0, 2, 3, 1)  # B=1,H,W,3
+    res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))
+
+    if conf_mode is not None:
+        res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
+    return res
+
+
+def reg_dense_depth(xyz, mode):
+    """
+    extract 3D points from prediction head output
+    """
+    mode, vmin, vmax = mode
+
+    no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
+    assert no_bounds
+
+    if mode == 'linear':
+        if no_bounds:
+            return xyz  # [-inf, +inf]
+        return xyz.clip(min=vmin, max=vmax)
+
+    # distance to origin
+    d = xyz.norm(dim=-1, keepdim=True)  # L2 norm over the channel dimension, i.e. over the x, y, z coordinates
+    xyz = xyz / d.clip(min=1e-8)  # divide by that norm, i.e. normalize to unit length
+
+    if mode == 'square':
+        return xyz * d.square()
+
+    if mode == 'exp':
+        return xyz * torch.expm1(d)
+
+    raise ValueError(f'bad {mode=}')
+
+
+def reg_dense_conf(x, mode):
+    """
+    extract confidence from prediction head output
+    """
+    mode, vmin, vmax = mode
+    if mode == 'exp':
+        return vmin + x.exp().clip(max=vmax-vmin)
+    if mode == 'sigmoid':
+        return (vmax - vmin) * torch.sigmoid(x) + vmin
+    raise ValueError(f'bad {mode=}')
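+
+
+# Illustrative sketch of the expected mode tuples (values assumed from typical
+# model defaults; the format is (name, vmin, vmax)):
+#   depth_mode = ('exp', -float('inf'), float('inf'))  # unbounded exp parametrization
+#   conf_mode = ('exp', 1, float('inf'))                # confidence = 1 + exp(x)
+#   res = postprocess(out, depth_mode, conf_mode)       # out: B x (3+1) x H x W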
diff --git a/dust3r/image_pairs.py b/dust3r/image_pairs.py
new file mode 100644
index 0000000000000000000000000000000000000000..9251dc822b6b4b11bb9149dfd256ee1e66947562
--- /dev/null
+++ b/dust3r/image_pairs.py
@@ -0,0 +1,83 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utilities needed to load image pairs
+# --------------------------------------------------------
+import numpy as np
+import torch
+
+
+def make_pairs(imgs, scene_graph='complete', prefilter=None, symmetrize=True):
+    pairs = []
+
+    if scene_graph == 'complete':  # complete graph
+        for i in range(len(imgs)):
+            for j in range(i):
+                pairs.append((imgs[i], imgs[j]))
+
+    elif scene_graph.startswith('swin'):
+        winsize = int(scene_graph.split('-')[1]) if '-' in scene_graph else 3
+        for i in range(len(imgs)):
+            for j in range(1, winsize + 1):  # skip j=0 to avoid pairing an image with itself
+                idx = (i + j) % len(imgs)  # explicit loop closure
+                pairs.append((imgs[i], imgs[idx]))
+
+    elif scene_graph.startswith('oneref'):
+        refid = int(scene_graph.split('-')[1]) if '-' in scene_graph else 0
+        for j in range(len(imgs)):
+            if j != refid:
+                pairs.append((imgs[refid], imgs[j]))
+
+    elif scene_graph == 'pairs':
+        assert len(imgs) % 2 == 0
+        for i in range(0, len(imgs), 2):
+            pairs.append((imgs[i], imgs[i+1]))
+
+    if symmetrize:
+        pairs += [(img2, img1) for img1, img2 in pairs]
+
+    # now, remove edges
+    if isinstance(prefilter, str) and prefilter.startswith('seq'):
+        pairs = filter_pairs_seq(pairs, int(prefilter[3:]))
+
+    if isinstance(prefilter, str) and prefilter.startswith('cyc'):
+        pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)
+
+    return pairs
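+
+# Illustrative examples of the supported scene_graph values (outputs assumed):
+#   make_pairs(imgs, 'complete')  -> all N*(N-1)/2 pairs (doubled if symmetrize=True)
+#   make_pairs(imgs, 'swin-3')    -> sliding window of size 3 with loop closure
+#   make_pairs(imgs, 'oneref-0')  -> every image paired with imgs[0]
+#   make_pairs(imgs, 'pairs')     -> disjoint consecutive pairs (0,1), (2,3), ...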
+
+
+def sel(x, kept):
+    if isinstance(x, dict):
+        return {k: sel(v, kept) for k, v in x.items()}
+    if isinstance(x, (torch.Tensor, np.ndarray)):
+        return x[kept]
+    if isinstance(x, (tuple, list)):
+        return type(x)([x[k] for k in kept])
+
+
+def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
+    # number of images
+    n = max(max(e) for e in edges)+1
+
+    kept = []
+    for e, (i, j) in enumerate(edges):
+        dis = abs(i-j)
+        if cyclic:
+            dis = min(dis, abs(i+n-j), abs(i-n-j))
+        if dis <= seq_dis_thr:
+            kept.append(e)
+    return kept
+
+
+def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
+    edges = [(img1['idx'], img2['idx']) for img1, img2 in pairs]
+    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
+    return [pairs[i] for i in kept]
+
+
+def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
+    edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
+    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
+    print(f'>> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges')
+    return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)
diff --git a/dust3r/inference.py b/dust3r/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..294d3593e03cdf88b154ad3d7dcef6be9c17acb0
--- /dev/null
+++ b/dust3r/inference.py
@@ -0,0 +1,165 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utilities needed for the inference
+# --------------------------------------------------------
+import tqdm
+import torch
+from dust3r.utils.device import to_cpu, collate_with_cat
+from dust3r.model import AsymmetricCroCo3DStereo, inf  # noqa: F401, needed when loading the model
+from dust3r.utils.misc import invalid_to_nans
+from dust3r.utils.geometry import depthmap_to_pts3d, geotrf
+
+
+def load_model(model_path, device):
+    print('... loading model from', model_path)
+    ckpt = torch.load(model_path, map_location='cpu')
+    args = ckpt['args'].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
+    if 'landscape_only' not in args:
+        args = args[:-1] + ', landscape_only=False)'
+    else:
+        args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
+    assert "landscape_only=False" in args
+    print(f"instantiating : {args}")
+    net = eval(args)
+    print(net.load_state_dict(ckpt['model'], strict=False))
+    return net.to(device)
+
+
+def _interleave_imgs(img1, img2):
+    res = {}
+    for key, value1 in img1.items():
+        value2 = img2[key]
+        if isinstance(value1, torch.Tensor):
+            value = torch.stack((value1, value2), dim=1).flatten(0, 1)
+        else:
+            value = [x for pair in zip(value1, value2) for x in pair]
+        res[key] = value
+    return res
+
+
+def make_batch_symmetric(batch):
+    view1, view2 = batch
+    view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
+    return view1, view2
+
+
+def loss_of_one_batch(batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None):
+    view1, view2 = batch  # the two input views fed to the model
+    for view in batch:  # move the input views' tensors to the target device
+        for name in 'img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres'.split():  # pseudo_focal
+            if name not in view:
+                continue
+            view[name] = view[name].to(device, non_blocking=True)  # move to device
+
+    if symmetrize_batch:
+        view1, view2 = make_batch_symmetric(batch)
+
+    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
+        pred1, pred2 = model(view1, view2) # model:AsymmetricCroCo3DStereo
+
+        # loss is supposed to be symmetric
+        with torch.cuda.amp.autocast(enabled=False):  # compute the loss in full precision
+            loss = criterion(view1, view2, pred1, pred2) if criterion is not None else None
+
+    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)  # loss is None when no criterion is given
+    return result[ret] if ret else result
+
+
+@torch.no_grad()
+def inference(pairs, model, device, batch_size=8):
+    print(f'>> Inference with model on {len(pairs)} image pairs')  # every combination of two input images
+    result = []
+
+    # first, check if all images have the same size
+    multiple_shapes = not (check_if_same_size(pairs))
+    if multiple_shapes:  # force bs=1
+        batch_size = 1
+
+    for i in tqdm.trange(0, len(pairs), batch_size):  # feed the pairs to the model batch by batch
+        res = loss_of_one_batch(collate_with_cat(pairs[i:i+batch_size]), model, None, device)
+        result.append(to_cpu(res))
+
+    result = collate_with_cat(result, lists=multiple_shapes)  # view1 and view2 hold the two images of each pair
+
+    torch.cuda.empty_cache()
+    return result
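+
+# Illustrative end-to-end sketch (assumed helpers from dust3r.utils.image and
+# dust3r.image_pairs):
+#   images = load_images(['img1.jpg', 'img2.jpg'], size=512)
+#   pairs = make_pairs(images, scene_graph='complete', symmetrize=True)
+#   output = inference(pairs, model, device, batch_size=1)
+#   pts3d, conf = output['pred1']['pts3d'], output['pred1']['conf']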
+
+
+def check_if_same_size(pairs):
+    shapes1 = [img1['img'].shape[-2:] for img1, img2 in pairs]
+    shapes2 = [img2['img'].shape[-2:] for img1, img2 in pairs]
+    return all(shapes1[0] == s for s in shapes1) and all(shapes2[0] == s for s in shapes2)
+
+
+def get_pred_pts3d(gt, pred, use_pose=False):
+    if 'depth' in pred and 'pseudo_focal' in pred:
+        try:
+            pp = gt['camera_intrinsics'][..., :2, 2]
+        except KeyError:
+            pp = None
+        pts3d = depthmap_to_pts3d(**pred, pp=pp)
+
+    elif 'pts3d' in pred:
+        # pts3d from my camera
+        pts3d = pred['pts3d']
+
+    elif 'pts3d_in_other_view' in pred:
+        # pts3d from the other camera, already transformed
+        assert use_pose is True
+        return pred['pts3d_in_other_view']  # return!
+
+    if use_pose:
+        camera_pose = pred.get('camera_pose')
+        assert camera_pose is not None
+        pts3d = geotrf(camera_pose, pts3d)
+
+    return pts3d
+
+
+def find_opt_scaling(gt_pts1, gt_pts2, pr_pts1, pr_pts2=None, fit_mode='weiszfeld_stop_grad', valid1=None, valid2=None):
+    assert gt_pts1.ndim == pr_pts1.ndim == 4
+    assert gt_pts1.shape == pr_pts1.shape
+    if gt_pts2 is not None:
+        assert gt_pts2.ndim == pr_pts2.ndim == 4
+        assert gt_pts2.shape == pr_pts2.shape
+
+    # concat the pointcloud
+    nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
+    nan_gt_pts2 = invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None
+
+    pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
+    pr_pts2 = invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None
+
+    all_gt = torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1) if gt_pts2 is not None else nan_gt_pts1
+    all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1
+
+    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
+    dot_gt_gt = all_gt.square().sum(dim=-1)
+
+    if fit_mode.startswith('avg'):
+        # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1)
+        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
+    elif fit_mode.startswith('median'):
+        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
+    elif fit_mode.startswith('weiszfeld'):
+        # init scaling with l2 closed form
+        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
+        # iterative re-weighted least-squares
+        for iter in range(10):
+            # re-weighting by inverse of distance
+            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
+            # print(dis.nanmean(-1))
+            w = dis.clip_(min=1e-8).reciprocal()
+            # update the scaling with the new weights
+            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
+    else:
+        raise ValueError(f'bad {fit_mode=}')
+
+    if fit_mode.endswith('stop_grad'):
+        scaling = scaling.detach()
+
+    scaling = scaling.clip(min=1e-3)
+    # assert scaling.isfinite().all(), bb()
+    return scaling
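A minimal sketch of what find_opt_scaling returns on synthetic data (shapes follow the asserts above; the ground-truth factor of 2.0 is invented for illustration):

    import torch
    from dust3r.inference import find_opt_scaling

    B, H, W = 2, 8, 8
    gt = torch.randn(B, H, W, 3)
    pred = 2.0 * gt + 0.01 * torch.randn_like(gt)   # predictions roughly 2x larger than GT
    valid = torch.ones(B, H, W, dtype=torch.bool)

    scale = find_opt_scaling(gt, None, pred, None, fit_mode='weiszfeld_stop_grad', valid1=valid)
    print(scale)  # one value per batch element, each close to 2.0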
diff --git a/dust3r/losses.py b/dust3r/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d6e20fd3a30d6d498afdc13ec852ae984d05f7e
--- /dev/null
+++ b/dust3r/losses.py
@@ -0,0 +1,297 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Implementation of DUSt3R training losses
+# --------------------------------------------------------
+from copy import copy, deepcopy
+import torch
+import torch.nn as nn
+
+from dust3r.inference import get_pred_pts3d, find_opt_scaling
+from dust3r.utils.geometry import inv, geotrf, normalize_pointcloud
+from dust3r.utils.geometry import get_joint_pointcloud_depth, get_joint_pointcloud_center_scale
+
+
+def Sum(*losses_and_masks):
+    loss, mask = losses_and_masks[0]
+    if loss.ndim > 0:
+        # we are actually returning the loss for every pixel
+        return losses_and_masks
+    else:
+        # we are returning the global loss
+        for loss2, mask2 in losses_and_masks[1:]:
+            loss = loss + loss2
+        return loss
+
+
+class LLoss (nn.Module):
+    """ L-norm loss
+    """
+
+    def __init__(self, reduction='mean'):
+        super().__init__()
+        self.reduction = reduction
+
+    def forward(self, a, b):
+        assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 3, f'Bad shape = {a.shape}'
+        dist = self.distance(a, b)
+        assert dist.ndim == a.ndim-1  # one dimension less
+        if self.reduction == 'none':
+            return dist
+        if self.reduction == 'sum':
+            return dist.sum()
+        if self.reduction == 'mean':
+            return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
+        raise ValueError(f'bad {self.reduction=} mode')
+
+    def distance(self, a, b):
+        raise NotImplementedError()
+
+
+class L21Loss (LLoss):
+    """ Euclidean distance between 3d points  """
+
+    def distance(self, a, b):
+        return torch.norm(a - b, dim=-1)  # normalized L2 distance
+
+
+L21 = L21Loss()
+
+
+class Criterion (nn.Module):
+    def __init__(self, criterion=None):
+        super().__init__()
+        assert isinstance(criterion, LLoss), f'{criterion} is not a proper criterion!'
+        self.criterion = copy(criterion)
+
+    def get_name(self):
+        return f'{type(self).__name__}({self.criterion})'
+
+    def with_reduction(self, mode):
+        res = loss = deepcopy(self)
+        while loss is not None:
+            assert isinstance(loss, Criterion)
+            loss.criterion.reduction = 'none'  # make it return the loss for each sample
+            loss = loss._loss2  # we assume loss is a Multiloss
+        return res
+
+
+class MultiLoss (nn.Module):
+    """ Easily combinable losses (also keep track of individual loss values):
+        loss = MyLoss1() + 0.1*MyLoss2()
+    Usage:
+        Inherit from this class and override get_name() and compute_loss()
+    """
+
+    def __init__(self):
+        super().__init__()
+        self._alpha = 1
+        self._loss2 = None
+
+    def compute_loss(self, *args, **kwargs):
+        raise NotImplementedError()
+
+    def get_name(self):
+        raise NotImplementedError()
+
+    def __mul__(self, alpha):
+        assert isinstance(alpha, (int, float))
+        res = copy(self)
+        res._alpha = alpha
+        return res
+    __rmul__ = __mul__  # same
+
+    def __add__(self, loss2):
+        assert isinstance(loss2, MultiLoss)
+        res = cur = copy(self)
+        # find the end of the chain
+        while cur._loss2 is not None:
+            cur = cur._loss2
+        cur._loss2 = loss2
+        return res
+
+    def __repr__(self):
+        name = self.get_name()
+        if self._alpha != 1:
+            name = f'{self._alpha:g}*{name}'
+        if self._loss2:
+            name = f'{name} + {self._loss2}'
+        return name
+
+    def forward(self, *args, **kwargs):
+        loss = self.compute_loss(*args, **kwargs)
+        if isinstance(loss, tuple):
+            loss, details = loss
+        elif loss.ndim == 0:
+            details = {self.get_name(): float(loss)}
+        else:
+            details = {}
+        loss = loss * self._alpha
+
+        if self._loss2:
+            loss2, details2 = self._loss2(*args, **kwargs)
+            loss = loss + loss2
+            details |= details2
+
+        return loss, details
+
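A quick illustration of the operator overloading documented in the MultiLoss docstring above (a minimal sketch; MyLoss1 and MyLoss2 are hypothetical subclasses defined only for this example):

    import torch
    from dust3r.losses import MultiLoss

    class MyLoss1(MultiLoss):
        def get_name(self): return 'MyLoss1'
        def compute_loss(self, *args, **kwargs): return torch.tensor(1.0)

    class MyLoss2(MultiLoss):
        def get_name(self): return 'MyLoss2'
        def compute_loss(self, *args, **kwargs): return torch.tensor(3.0)

    loss = MyLoss1() + 0.1 * MyLoss2()
    print(loss)          # MyLoss1 + 0.1*MyLoss2
    total, details = loss()
    print(float(total))  # 1.0 + 0.1*3.0 = 1.3
    print(details)       # {'MyLoss1': 1.0, 'MyLoss2': 3.0}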
+
+class Regr3D (Criterion, MultiLoss):
+    """ Ensure that all 3D points are correct.
+        Asymmetric loss: view1 is supposed to be the anchor.
+
+        P1 = RT1 @ D1
+        P2 = RT2 @ D2
+        loss1 = (I @ pred_D1) - (RT1^-1 @ RT1 @ D1)
+        loss2 = (RT21 @ pred_D2) - (RT1^-1 @ P2)
+              = (RT21 @ pred_D2) - (RT1^-1 @ RT2 @ D2)
+    """
+
+    def __init__(self, criterion, norm_mode='avg_dis', gt_scale=False):
+        super().__init__(criterion)
+        self.norm_mode = norm_mode
+        self.gt_scale = gt_scale
+
+    def get_all_pts3d(self, gt1, gt2, pred1, pred2, dist_clip=None):
+        # everything is normalized w.r.t. camera of view1
+        in_camera1 = inv(gt1['camera_pose'])
+        gt_pts1 = geotrf(in_camera1, gt1['pts3d'])  # B,H,W,3
+        gt_pts2 = geotrf(in_camera1, gt2['pts3d'])  # B,H,W,3
+
+        valid1 = gt1['valid_mask'].clone()
+        valid2 = gt2['valid_mask'].clone()
+
+        if dist_clip is not None:
+            # points that are too far-away == invalid
+            dis1 = gt_pts1.norm(dim=-1)  # (B, H, W)
+            dis2 = gt_pts2.norm(dim=-1)  # (B, H, W)
+            valid1 = valid1 & (dis1 <= dist_clip)
+            valid2 = valid2 & (dis2 <= dist_clip)
+
+        pr_pts1 = get_pred_pts3d(gt1, pred1, use_pose=False)
+        pr_pts2 = get_pred_pts3d(gt2, pred2, use_pose=True)
+
+        # normalize 3d points
+        if self.norm_mode:
+            pr_pts1, pr_pts2 = normalize_pointcloud(pr_pts1, pr_pts2, self.norm_mode, valid1, valid2)
+        if self.norm_mode and not self.gt_scale:
+            gt_pts1, gt_pts2 = normalize_pointcloud(gt_pts1, gt_pts2, self.norm_mode, valid1, valid2)
+
+        return gt_pts1, gt_pts2, pr_pts1, pr_pts2, valid1, valid2, {}
+
+    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
+        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \
+            self.get_all_pts3d(gt1, gt2, pred1, pred2, **kw)
+        # loss on img1 side
+        l1 = self.criterion(pred_pts1[mask1], gt_pts1[mask1])
+        # loss on gt2 side
+        l2 = self.criterion(pred_pts2[mask2], gt_pts2[mask2])
+        self_name = type(self).__name__
+        details = {self_name+'_pts3d_1': float(l1.mean()), self_name+'_pts3d_2': float(l2.mean())}
+        return Sum((l1, mask1), (l2, mask2)), (details | monitoring)
+
+
+class ConfLoss (MultiLoss):
+    """ Weighted regression by learned confidence.
+        Assuming the input pixel_loss is a pixel-level regression loss.
+
+    Principle:
+        high confidence means large conf = 10  ==> conf_loss = x * 10 - alpha*log(10)
+        low  confidence means small conf = 0.1 ==> conf_loss = x / 10 + alpha*log(10)
+
+        alpha: hyperparameter
+    """
+
+    def __init__(self, pixel_loss, alpha=1):
+        super().__init__()
+        assert alpha > 0
+        self.alpha = alpha
+        self.pixel_loss = pixel_loss.with_reduction('none')
+
+    def get_name(self):
+        return f'ConfLoss({self.pixel_loss})'
+
+    def get_conf_log(self, x):
+        return x, torch.log(x)
+
+    def compute_loss(self, gt1, gt2, pred1, pred2, **kw):
+        # compute per-pixel loss
+        ((loss1, msk1), (loss2, msk2)), details = self.pixel_loss(gt1, gt2, pred1, pred2, **kw)
+        if loss1.numel() == 0:
+            print('NO VALID POINTS in img1', force=True)
+        if loss2.numel() == 0:
+            print('NO VALID POINTS in img2', force=True)
+
+        # weight by confidence
+        conf1, log_conf1 = self.get_conf_log(pred1['conf'][msk1])
+        conf2, log_conf2 = self.get_conf_log(pred2['conf'][msk2])
+        conf_loss1 = loss1 * conf1 - self.alpha * log_conf1
+        conf_loss2 = loss2 * conf2 - self.alpha * log_conf2
+
+        # average + nan protection (in case of no valid pixels at all)
+        conf_loss1 = conf_loss1.mean() if conf_loss1.numel() > 0 else 0
+        conf_loss2 = conf_loss2.mean() if conf_loss2.numel() > 0 else 0
+
+        return conf_loss1 + conf_loss2, dict(conf_loss_1=float(conf_loss1), conf_loss2=float(conf_loss2), **details)
+
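To make the confidence weighting concrete, a toy computation with invented numbers (alpha = 1 and a per-pixel loss of 0.5, under a high and a low learned confidence):

    import torch

    pixel_loss = torch.tensor([0.5, 0.5])
    conf = torch.tensor([10.0, 0.1])
    alpha = 1.0
    conf_loss = pixel_loss * conf - alpha * torch.log(conf)
    print(conf_loss)  # tensor([2.6974, 2.3526]): low confidence damps the error but pays a log penalty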
+
+class Regr3D_ShiftInv (Regr3D):
+    """ Same than Regr3D but invariant to depth shift.
+    """
+
+    def get_all_pts3d(self, gt1, gt2, pred1, pred2):
+        # compute unnormalized points
+        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = \
+            super().get_all_pts3d(gt1, gt2, pred1, pred2)
+
+        # compute median depth
+        gt_z1, gt_z2 = gt_pts1[..., 2], gt_pts2[..., 2]
+        pred_z1, pred_z2 = pred_pts1[..., 2], pred_pts2[..., 2]
+        gt_shift_z = get_joint_pointcloud_depth(gt_z1, gt_z2, mask1, mask2)[:, None, None]
+        pred_shift_z = get_joint_pointcloud_depth(pred_z1, pred_z2, mask1, mask2)[:, None, None]
+
+        # subtract the median depth
+        gt_z1 -= gt_shift_z
+        gt_z2 -= gt_shift_z
+        pred_z1 -= pred_shift_z
+        pred_z2 -= pred_shift_z
+
+        # monitoring = dict(monitoring, gt_shift_z=gt_shift_z.mean().detach(), pred_shift_z=pred_shift_z.mean().detach())
+        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring
+
+
+class Regr3D_ScaleInv (Regr3D):
+    """ Same than Regr3D but invariant to depth shift.
+        if gt_scale == True: enforce the prediction to take the same scale than GT
+    """
+
+    def get_all_pts3d(self, gt1, gt2, pred1, pred2):
+        # compute depth-normalized points
+        gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring = super().get_all_pts3d(gt1, gt2, pred1, pred2)
+
+        # measure scene scale
+        _, gt_scale = get_joint_pointcloud_center_scale(gt_pts1, gt_pts2, mask1, mask2)
+        _, pred_scale = get_joint_pointcloud_center_scale(pred_pts1, pred_pts2, mask1, mask2)
+
+        # prevent predictions to be in a ridiculous range
+        pred_scale = pred_scale.clip(min=1e-3, max=1e3)
+
+        # rescale predictions (or GT) so that both share a common scale
+        if self.gt_scale:
+            pred_pts1 *= gt_scale / pred_scale
+            pred_pts2 *= gt_scale / pred_scale
+            # monitoring = dict(monitoring, pred_scale=(pred_scale/gt_scale).mean())
+        else:
+            gt_pts1 /= gt_scale
+            gt_pts2 /= gt_scale
+            pred_pts1 /= pred_scale
+            pred_pts2 /= pred_scale
+            # monitoring = dict(monitoring, gt_scale=gt_scale.mean(), pred_scale=pred_scale.mean().detach())
+
+        return gt_pts1, gt_pts2, pred_pts1, pred_pts2, mask1, mask2, monitoring
+
+
+class Regr3D_ScaleShiftInv (Regr3D_ScaleInv, Regr3D_ShiftInv):
+    # calls Regr3D_ShiftInv first, then Regr3D_ScaleInv
+    pass
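Putting this file together, a hedged sketch of how a criterion could be assembled from these classes (the alpha value and norm mode below are illustrative choices, not mandated by this diff):

    from dust3r.losses import ConfLoss, Regr3D, Regr3D_ScaleShiftInv, L21

    # confidence-weighted 3D regression, pointmaps normalized by their average distance to the origin
    train_criterion = ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)

    # scale- and shift-invariant variant, with predictions forced to the GT scale
    eval_criterion = Regr3D_ScaleShiftInv(L21, gt_scale=True)

    # both are called as criterion(gt1, gt2, pred1, pred2) and return (loss, details_dict)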
diff --git a/dust3r/model.py b/dust3r/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..93099d522497d0d5bf2426477c68148f08b46b7b
--- /dev/null
+++ b/dust3r/model.py
@@ -0,0 +1,166 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# DUSt3R model class
+# --------------------------------------------------------
+from copy import deepcopy
+import torch
+
+from .utils.misc import fill_default_args, freeze_all_params, is_symmetrized, interleave, transpose_to_landscape
+from .heads import head_factory
+from dust3r.patch_embed import get_patch_embed
+
+import dust3r.utils.path_to_croco  # noqa: F401
+from croco.models.croco import CroCoNet  # noqa
+inf = float('inf')
+
+
+class AsymmetricCroCo3DStereo (CroCoNet):
+    """ Two siamese encoders, followed by two decoders.
+    The goal is to output 3d points directly, both images in view1's frame
+    (hence the asymmetry).   
+    """
+
+    def __init__(self,
+                 output_mode='pts3d',
+                 head_type='linear',
+                 depth_mode=('exp', -inf, inf),
+                 conf_mode=('exp', 1, inf),
+                 freeze='none',
+                 landscape_only=True,
+                 patch_embed_cls='PatchEmbedDust3R',  # PatchEmbedDust3R or ManyAR_PatchEmbed
+                 **croco_kwargs):
+        self.patch_embed_cls = patch_embed_cls
+        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
+        super().__init__(**croco_kwargs)
+
+        # dust3rWithSam2 specific initialization
+        self.dec_blocks2 = deepcopy(self.dec_blocks)
+        self.set_downstream_head(output_mode, head_type, landscape_only, depth_mode, conf_mode, **croco_kwargs)
+        self.set_freeze(freeze)
+
+    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
+        self.patch_embed = get_patch_embed(self.patch_embed_cls, img_size, patch_size, enc_embed_dim)
+
+    def load_state_dict(self, ckpt, **kw):
+        # duplicate all weights for the second decoder if not present
+        new_ckpt = dict(ckpt)
+        if not any(k.startswith('dec_blocks2') for k in ckpt):
+            for key, value in ckpt.items():
+                if key.startswith('dec_blocks'):
+                    new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
+        return super().load_state_dict(new_ckpt, **kw)
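The decoder-weight duplication performed by load_state_dict, re-run standalone on a toy dict (not a real checkpoint) to show which keys the second decoder receives:

    import torch

    ckpt = {'enc_blocks.0.norm1.weight': torch.ones(3),
            'dec_blocks.0.attn.qkv.weight': torch.zeros(3)}
    new_ckpt = dict(ckpt)
    if not any(k.startswith('dec_blocks2') for k in ckpt):
        for key, value in ckpt.items():
            if key.startswith('dec_blocks'):
                new_ckpt[key.replace('dec_blocks', 'dec_blocks2')] = value
    print(sorted(new_ckpt))
    # ['dec_blocks.0.attn.qkv.weight', 'dec_blocks2.0.attn.qkv.weight', 'enc_blocks.0.norm1.weight']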
+
+    def set_freeze(self, freeze):  # this is for use by downstream models
+        self.freeze = freeze
+        to_be_frozen = {
+            'none':     [],
+            'mask':     [self.mask_token],
+            'encoder':  [self.mask_token, self.patch_embed, self.enc_blocks],
+        }
+        freeze_all_params(to_be_frozen[freeze])
+
+    def _set_prediction_head(self, *args, **kwargs):
+        """ No prediction head """
+        return
+
+    def set_downstream_head(self, output_mode, head_type, landscape_only, depth_mode, conf_mode, patch_size, img_size,
+                            **kw):
+        assert img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0, \
+            f'{img_size=} must be multiple of {patch_size=}'
+        self.output_mode = output_mode
+        self.head_type = head_type
+        self.depth_mode = depth_mode
+        self.conf_mode = conf_mode
+        # allocate heads
+        self.downstream_head1 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
+        self.downstream_head2 = head_factory(head_type, output_mode, self, has_conf=bool(conf_mode))
+        # magic wrapper
+        self.head1 = transpose_to_landscape(self.downstream_head1, activate=landscape_only)
+        self.head2 = transpose_to_landscape(self.downstream_head2, activate=landscape_only)
+
+    def _encode_image(self, image, true_shape):  # image: the two input images concatenated along the batch dimension
+        # embed the image into patches  (x has size B x Npatches x C)
+        x, pos = self.patch_embed(image, true_shape=true_shape)  # calls PatchEmbedDust3R to compute the patch embeddings and their positions
+
+        # add positional embedding without cls token
+        assert self.enc_pos_embed is None
+
+        # now apply the transformer encoder and normalization
+        for blk in self.enc_blocks:  # run through all encoder blocks (24 for the ViT-Large configuration)
+            x = blk(x, pos)
+
+        x = self.enc_norm(x) # LayerNorm
+        return x, pos, None
+
+    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
+        if img1.shape[-2:] == img2.shape[-2:]:
+            out, pos, _ = self._encode_image(torch.cat((img1, img2), dim=0),  # concatenate the two images along the batch dimension
+                                             torch.cat((true_shape1, true_shape2), dim=0))
+            out, out2 = out.chunk(2, dim=0)
+            pos, pos2 = pos.chunk(2, dim=0)
+        else:  # different shapes: encode each image separately with the ViT encoder
+            out, pos, _ = self._encode_image(img1, true_shape1)
+            out2, pos2, _ = self._encode_image(img2, true_shape2)
+        return out, out2, pos, pos2
+
+    def _encode_symmetrized(self, view1, view2):
+        img1 = view1['img']
+        img2 = view2['img']
+        B = img1.shape[0]
+        # Recover true_shape when available, otherwise assume that the img shape is the true one
+        shape1 = view1.get('true_shape', torch.tensor(img1.shape[-2:])[None].repeat(B, 1))
+        shape2 = view2.get('true_shape', torch.tensor(img2.shape[-2:])[None].repeat(B, 1))
+        # warning! maybe the images have different portrait/landscape orientations
+
+        if is_symmetrized(view1, view2):
+            # computing half of forward pass!
+            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1[::2], img2[::2], shape1[::2], shape2[::2])
+            feat1, feat2 = interleave(feat1, feat2)
+            pos1, pos2 = interleave(pos1, pos2)
+        else:  # not symmetrized: encode both images with the ViT encoder
+            feat1, feat2, pos1, pos2 = self._encode_image_pairs(img1, img2, shape1, shape2)
+
+        return (shape1, shape2), (feat1, feat2), (pos1, pos2)
+
+    def _decoder(self, f1, pos1, f2, pos2):
+        final_output = [(f1, f2)]  # the two encoder outputs, before projection to the decoder dimension
+
+        # project to decoder dim (a single Linear layer)
+        f1 = self.decoder_embed(f1)  # Linear layer, channels: 1024 -> 768
+        f2 = self.decoder_embed(f2)
+
+        final_output.append((f1, f2))                   # the two embeddings after projection
+        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):  # dec_blocks2 is a deepcopy of dec_blocks, so both start identical
+            # img1 side: *final_output[-1][::+1] passes (f1, f2)
+            f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
+            # img2 side: *final_output[-1][::-1] passes (f2, f1)
+            f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
+            # store the result
+            final_output.append((f1, f2))
+
+        # normalize last output
+        del final_output[1]  # near-duplicate of final_output[0]: drop the post-projection copy
+        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
+        return zip(*final_output)
+
+    def _downstream_head(self, head_num, decout, img_shape):
+        B, S, D = decout[-1].shape
+        # img_shape = tuple(map(int, img_shape))
+        head = getattr(self, f'head{head_num}')
+        return head(decout, img_shape)
+
+    def forward(self, view1, view2):
+        # encode the two images --> B,S,D (ViT encoder)
+        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(view1, view2)
+
+        # combine all ref images into object-centric representation (decoder)
+        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)
+
+        with torch.cuda.amp.autocast(enabled=False):  # feed the decoder outputs into head1 and head2 respectively
+            res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1) # PixelwiseTaskWithDPT
+            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)
+
+        res2['pts3d_in_other_view'] = res2.pop('pts3d')  # predict view2's pts3d in view1's frame, i.e. the 3D points in res2 are expressed in view1's camera coordinate system
+        return res1, res2
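A small sketch of the symmetrized-batch trick used by _encode_symmetrized: when the batch contains every pair twice, as (A,B) and (B,A), only every other sample is encoded and the features are interleaved back (the feature values below are synthetic):

    import torch
    from dust3r.utils.misc import interleave

    feat_a = torch.arange(2).view(2, 1).float()       # features of images A0, A1
    feat_b = 10 + torch.arange(2).view(2, 1).float()  # features of images B0, B1
    f1, f2 = interleave(feat_a, feat_b)
    print(f1.squeeze().tolist())  # [0.0, 10.0, 1.0, 11.0]
    print(f2.squeeze().tolist())  # [10.0, 0.0, 11.0, 1.0]
    # together they reconstruct the pairs (A0,B0), (B0,A0), (A1,B1), (B1,A1)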
diff --git a/dust3r/optim_factory.py b/dust3r/optim_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b9c16e0e0fda3fd03c3def61abc1f354f75c584
--- /dev/null
+++ b/dust3r/optim_factory.py
@@ -0,0 +1,14 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# optimization functions
+# --------------------------------------------------------
+
+
+def adjust_learning_rate_by_lr(optimizer, lr):
+    for param_group in optimizer.param_groups:
+        if "lr_scale" in param_group:
+            param_group["lr"] = lr * param_group["lr_scale"]
+        else:
+            param_group["lr"] = lr
diff --git a/dust3r/patch_embed.py b/dust3r/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..23b6cb0cabf6b5351451e4cfb297e182c4128d1f
--- /dev/null
+++ b/dust3r/patch_embed.py
@@ -0,0 +1,70 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# PatchEmbed implementation for DUST3R,
+# in particular ManyAR_PatchEmbed, which handles images with a non-square aspect ratio
+# --------------------------------------------------------
+import torch
+import dust3r.utils.path_to_croco  # noqa: F401
+from models.blocks import PatchEmbed  # noqa
+
+
+def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
+    assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed']
+    patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim)
+    return patch_embed
+
+
+class PatchEmbedDust3R(PatchEmbed):
+    def forward(self, x, **kw):
+        B, C, H, W = x.shape  # the input image dimensions must be multiples of the patch size (16)
+        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
+        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
+        x = self.proj(x)  # a Conv2d, i.e. the ViT patch-embedding step: output embedding dim 1024 (ViT-L), kernel size and stride both 16
+        pos = self.position_getter(B, x.size(2), x.size(3), x.device)  # PositionGetter: per-token positions used for the positional encoding
+        if self.flatten:
+            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+        x = self.norm(x) # nn.Identity()
+        return x, pos
+
+
+class ManyAR_PatchEmbed (PatchEmbed):
+    """ Handle images with non-square aspect ratio.
+        All images in the same batch have the same aspect ratio.
+        true_shape = [(height, width) ...] indicates the actual shape of each image.
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+        self.embed_dim = embed_dim
+        super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)
+
+    def forward(self, img, true_shape):
+        B, C, H, W = img.shape
+        assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'
+        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
+        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
+        assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}"
+
+        # size expressed in tokens
+        W //= self.patch_size[0]
+        H //= self.patch_size[1]
+        n_tokens = H * W
+
+        height, width = true_shape.T
+        is_landscape = (width >= height)
+        is_portrait = ~is_landscape
+
+        # allocate result
+        x = img.new_zeros((B, n_tokens, self.embed_dim))
+        pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)
+
+        # linear projection, transposed if necessary
+        x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()
+        x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()
+
+        pos[is_landscape] = self.position_getter(1, H, W, pos.device)
+        pos[is_portrait] = self.position_getter(1, W, H, pos.device)
+
+        x = self.norm(x)
+        return x, pos
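The token-count arithmetic implied by the assertions above, for a hypothetical 512x384 input with 16x16 patches:

    H, W, P = 384, 512, 16
    assert H % P == 0 and W % P == 0
    h, w = H // P, W // P
    print(h, w, h * w)  # 24 32 768 -> x has shape (B, 768, embed_dim) and pos has shape (B, 768, 2)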
diff --git a/dust3r/post_process.py b/dust3r/post_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..e453f7e1c0e0fa2e2729cc3a2f57f9a0dc5ed025
--- /dev/null
+++ b/dust3r/post_process.py
@@ -0,0 +1,60 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utilities for interpreting the DUST3R output
+# --------------------------------------------------------
+import numpy as np
+import torch
+from dust3r.utils.geometry import xy_grid
+
+# Estimate the focal length f, following the formula from Sec. 3.3 "Downstream Applications - Recovering intrinsics" of the paper
+def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0.5, max_focal=3.5):
+    """ Reprojection method, for when the absolute depth is known:
+        1) estimate the camera focal using a robust estimator
+        2) reproject points onto true rays, minimizing a certain error
+    """
+    B, H, W, THREE = pts3d.shape
+    assert THREE == 3
+
+    # pixels are the centered image coordinates (i', j') from the paper: i' = i - W/2, j' = j - H/2 (here, centered on the principal point pp)
+    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2)  # B,HW,2
+    pts3d = pts3d.flatten(1, 2)  # (B, H*W, 3)
+
+    if focal_mode == 'median':
+        with torch.no_grad():
+            # direct estimation of focal
+            u, v = pixels.unbind(dim=-1)
+            x, y, z = pts3d.unbind(dim=-1)
+            fx_votes = (u * z) / x
+            fy_votes = (v * z) / y
+
+            # assume square pixels, hence same focal for X and Y
+            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
+            focal = torch.nanmedian(f_votes, dim=-1).values
+
+    elif focal_mode == 'weiszfeld':  # reference (in Chinese): https://blog.csdn.net/qianlinjun/article/details/53852306
+        # init focal with l2 closed form
+        # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|
+        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0)  # normalized coordinates, i.e. x and y divided by z
+        # 1. initialize the focal for the first iteration
+        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
+        dot_xy_xy = xy_over_z.square().sum(dim=-1)
+
+        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)
+        # 2. iterate using the Weiszfeld algorithm
+        # iterative re-weighted least-squares
+        for iter in range(10):
+            # re-weighting by inverse of distance
+            dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)  # norm: L2 norm of the reprojection residual
+            # print(dis.nanmean(-1))
+            w = dis.clip(min=1e-8).reciprocal()  # reciprocal, i.e. 1/dis
+            # update the scaling with the new weights
+            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
+    else:
+        raise ValueError(f'bad {focal_mode=}')
+
+    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
+    focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base)
+    # print(focal)
+    return focal
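A self-contained check of the estimator above on a synthetic pinhole pointmap (image size, depth and the true focal of 100 px are invented for illustration):

    import torch
    from dust3r.utils.geometry import xy_grid
    from dust3r.post_process import estimate_focal_knowing_depth

    H, W, f_true = 48, 64, 100.0
    pp = torch.tensor([[W / 2, H / 2]])
    uv = xy_grid(W, H, device='cpu').float().view(1, H, W, 2) - pp.view(1, 1, 1, 2)
    z = torch.full((1, H, W, 1), 5.0)
    pts3d = torch.cat((uv * z / f_true, z), dim=-1)  # x = u'*z/f, y = v'*z/f

    f_est = estimate_focal_knowing_depth(pts3d, pp, focal_mode='weiszfeld')
    print(float(f_est))  # close to 100.0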
diff --git a/dust3r/render_to_3d.py b/dust3r/render_to_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..63f742430cdb311cf2f543ff6e0fd6912342c27d
--- /dev/null
+++ b/dust3r/render_to_3d.py
@@ -0,0 +1,91 @@
+import os
+import torch
+import numpy as np
+import trimesh
+from scipy.spatial.transform import Rotation
+
+from dust3r.utils.device import to_numpy
+from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
+
+import matplotlib.pyplot as plt
+plt.ion()
+
+torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
+batch_size = 1
+
+
+# export the reconstructed 3D scene to a glb file at the outfile path
+def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
+                                 cam_color=None, as_pointcloud=False, transparent_cams=False):
+    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
+    pts3d = to_numpy(pts3d)
+    imgs = to_numpy(imgs)
+    focals = to_numpy(focals)
+    cams2world = to_numpy(cams2world)
+
+    scene = trimesh.Scene()
+
+    # full pointcloud
+    if as_pointcloud:
+        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
+        col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
+        pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
+        scene.add_geometry(pct)
+    else:
+        meshes = []
+        for i in range(len(imgs)):
+            meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
+        mesh = trimesh.Trimesh(**cat_meshes(meshes))
+        scene.add_geometry(mesh)
+
+    # add each camera
+    for i, pose_c2w in enumerate(cams2world):
+        if isinstance(cam_color, list):
+            camera_edge_color = cam_color[i]
+        else:
+            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
+        add_scene_cam(scene, pose_c2w, camera_edge_color,
+                      None if transparent_cams else imgs[i], focals[i],
+                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)
+
+    rot = np.eye(4)
+    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
+    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
+    outfile = os.path.join(outdir, 'scene.glb')
+    print('(exporting 3D scene to', outfile, ')')
+    os.makedirs(outdir, exist_ok=True)
+    scene.export(file_obj=outfile)
+    return outfile
+
+
+
+def get_3D_model_from_scene(outdir, scene, sam2_masks, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
+                            clean_depth=False, transparent_cams=False, cam_size=0.05):
+    """
+    extract 3D_model (glb file) from a reconstructed scene
+    """
+    if scene is None:
+        return None
+    # post processes
+    if clean_depth:
+        scene = scene.clean_pointcloud()
+    if mask_sky:
+        scene = scene.mask_sky()
+
+    # get optimized values from scene
+    rgbimg = scene.imgs
+
+    focals = scene.get_focals().cpu()
+    cams2world = scene.get_im_poses().cpu()
+    # 3D pointcloud from depthmap, poses and intrinsics
+    pts3d = to_numpy(scene.get_pts3d())
+    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
+    msk = to_numpy(scene.get_masks())
+
+    assert len(msk) == len(sam2_masks)
+    # intersect the masks predicted by SAM2 with the confidence-thresholded masks from dust3r
+    for i in range(len(sam2_masks)):
+        msk[i] = np.logical_and(msk[i], sam2_masks[i])
+
+    return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
+                                        transparent_cams=transparent_cams, cam_size=cam_size), msk  # intersection of the confidence masks and the SAM2 masks
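A toy illustration of the mask intersection performed above (arrays invented; in practice they come from the scene confidence threshold and from SAM2):

    import numpy as np

    conf_msk = np.array([[True, True], [False, True]])   # dust3r confidence >= min_conf_thr
    sam2_msk = np.array([[True, False], [False, True]])  # SAM2 segmentation mask
    print(np.logical_and(conf_msk, sam2_msk))
    # [[ True False]
    #  [False  True]]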
diff --git a/dust3r/utils/__init__.py b/dust3r/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a32692113d830ddc4af4e6ed608f222fbe062e6e
--- /dev/null
+++ b/dust3r/utils/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
diff --git a/dust3r/utils/__pycache__/__init__.cpython-310.pyc b/dust3r/utils/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf6e2afbd1a015e26cb11835fe9c3d6a11c05c01
Binary files /dev/null and b/dust3r/utils/__pycache__/__init__.cpython-310.pyc differ
diff --git a/dust3r/utils/__pycache__/__init__.cpython-38.pyc b/dust3r/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..07a9d047b7ca97b0888f9507515d3d1ca20d6af6
Binary files /dev/null and b/dust3r/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/dust3r/utils/__pycache__/__init__.cpython-39.pyc b/dust3r/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..33ae796e40503d0515ddf01c64c42982bf8b3468
Binary files /dev/null and b/dust3r/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/dust3r/utils/__pycache__/device.cpython-310.pyc b/dust3r/utils/__pycache__/device.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..33e69f58cc58479e35d1d56e9b203c17a9a001d4
Binary files /dev/null and b/dust3r/utils/__pycache__/device.cpython-310.pyc differ
diff --git a/dust3r/utils/__pycache__/device.cpython-38.pyc b/dust3r/utils/__pycache__/device.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dbb8b155cfc7b55a489e19b9a7fbf1e334a840e0
Binary files /dev/null and b/dust3r/utils/__pycache__/device.cpython-38.pyc differ
diff --git a/dust3r/utils/__pycache__/device.cpython-39.pyc b/dust3r/utils/__pycache__/device.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d3d206ff8ca979e37ee838eec1d19ac8f8a5e66
Binary files /dev/null and b/dust3r/utils/__pycache__/device.cpython-39.pyc differ
diff --git a/dust3r/utils/__pycache__/geometry.cpython-310.pyc b/dust3r/utils/__pycache__/geometry.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3c7f9ddeb45980a8bbd3bcdbf14c2ac7128baea
Binary files /dev/null and b/dust3r/utils/__pycache__/geometry.cpython-310.pyc differ
diff --git a/dust3r/utils/__pycache__/geometry.cpython-38.pyc b/dust3r/utils/__pycache__/geometry.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6bc6743772ebae02032bceb13607aa41da01cd28
Binary files /dev/null and b/dust3r/utils/__pycache__/geometry.cpython-38.pyc differ
diff --git a/dust3r/utils/__pycache__/image.cpython-310.pyc b/dust3r/utils/__pycache__/image.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4a744006e8d8f2d06f5e053f339e760f75b6e2a2
Binary files /dev/null and b/dust3r/utils/__pycache__/image.cpython-310.pyc differ
diff --git a/dust3r/utils/__pycache__/image.cpython-38.pyc b/dust3r/utils/__pycache__/image.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de4f0513e317e2eff5a15089eeb00f2f99d7593c
Binary files /dev/null and b/dust3r/utils/__pycache__/image.cpython-38.pyc differ
diff --git a/dust3r/utils/__pycache__/misc.cpython-310.pyc b/dust3r/utils/__pycache__/misc.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16f56942d7ee7dc8636e877185d6095ff766aafa
Binary files /dev/null and b/dust3r/utils/__pycache__/misc.cpython-310.pyc differ
diff --git a/dust3r/utils/__pycache__/misc.cpython-38.pyc b/dust3r/utils/__pycache__/misc.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a3268307231ff6ace455c4fefe41ef601967e1d8
Binary files /dev/null and b/dust3r/utils/__pycache__/misc.cpython-38.pyc differ
diff --git a/dust3r/utils/__pycache__/misc.cpython-39.pyc b/dust3r/utils/__pycache__/misc.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f77dcf7d21464b9d50842ef0290f28186e12a4a6
Binary files /dev/null and b/dust3r/utils/__pycache__/misc.cpython-39.pyc differ
diff --git a/dust3r/utils/__pycache__/path_to_croco.cpython-310.pyc b/dust3r/utils/__pycache__/path_to_croco.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e8ec00cefee3dcff1ed3fa817313f5ddb7380b6e
Binary files /dev/null and b/dust3r/utils/__pycache__/path_to_croco.cpython-310.pyc differ
diff --git a/dust3r/utils/__pycache__/path_to_croco.cpython-38.pyc b/dust3r/utils/__pycache__/path_to_croco.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a613d9bc286cdbd21a0c691206d5845970a01603
Binary files /dev/null and b/dust3r/utils/__pycache__/path_to_croco.cpython-38.pyc differ
diff --git a/dust3r/utils/device.py b/dust3r/utils/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3b6a74dac05a2e1ba3a2b2f0faa8cea08ece745
--- /dev/null
+++ b/dust3r/utils/device.py
@@ -0,0 +1,76 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utility functions for DUSt3R
+# --------------------------------------------------------
+import numpy as np
+import torch
+
+
+def todevice(batch, device, callback=None, non_blocking=False):
+    ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
+
+    batch: list, tuple, dict of tensors or other things
+    device: pytorch device or 'numpy'
+    callback: function that will be called on every sub-element.
+    '''
+    if callback:
+        batch = callback(batch)
+
+    if isinstance(batch, dict):
+        return {k: todevice(v, device) for k, v in batch.items()}
+
+    if isinstance(batch, (tuple, list)):
+        return type(batch)(todevice(x, device) for x in batch)
+
+    x = batch
+    if device == 'numpy':
+        if isinstance(x, torch.Tensor):
+            x = x.detach().cpu().numpy()
+    elif x is not None:
+        if isinstance(x, np.ndarray):
+            x = torch.from_numpy(x)
+        if torch.is_tensor(x):
+            x = x.to(device, non_blocking=non_blocking)
+    return x
+
+
+to_device = todevice  # alias
+
+
+def to_numpy(x): return todevice(x, 'numpy')
+def to_cpu(x): return todevice(x, 'cpu')
+def to_cuda(x): return todevice(x, 'cuda')
+
+
+def collate_with_cat(whatever, lists=False):
+    if isinstance(whatever, dict):
+        return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}
+
+    elif isinstance(whatever, (tuple, list)):
+        if len(whatever) == 0:
+            return whatever
+        elem = whatever[0]
+        T = type(whatever)
+
+        if elem is None:
+            return None
+        if isinstance(elem, (bool, float, int, str)):
+            return whatever
+        if isinstance(elem, tuple):
+            return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
+        if isinstance(elem, dict):
+            return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem}
+
+        if isinstance(elem, torch.Tensor):
+            return listify(whatever) if lists else torch.cat(whatever)
+        if isinstance(elem, np.ndarray):
+            return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever])
+
+        # otherwise, we just chain lists
+        return sum(whatever, T())
+
+
+def listify(elems):
+    return [x for e in elems for x in e]
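A small usage sketch of the helpers above on a nested batch (toy tensors; the keys mimic the image dicts used elsewhere in this diff):

    import torch
    from dust3r.utils.device import collate_with_cat, to_numpy

    batch = [{'img': torch.zeros(1, 3, 4, 4), 'idx': 0},
             {'img': torch.ones(1, 3, 4, 4), 'idx': 1}]
    merged = collate_with_cat(batch)            # tensors are concatenated along dim 0
    print(merged['img'].shape, merged['idx'])   # torch.Size([2, 3, 4, 4]) [0, 1]
    print(type(to_numpy(merged)['img']))        # <class 'numpy.ndarray'>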
diff --git a/dust3r/utils/geometry.py b/dust3r/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..648a72ec6498c481c357b732c1ef389e83c7422f
--- /dev/null
+++ b/dust3r/utils/geometry.py
@@ -0,0 +1,361 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# geometry utility functions
+# --------------------------------------------------------
+import torch
+import numpy as np
+from scipy.spatial import cKDTree as KDTree
+
+from dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
+from dust3r.utils.device import to_numpy
+
+
+def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
+    """ Output a (H,W,2) array of int32 
+        with output[j,i,0] = i + origin[0]
+             output[j,i,1] = j + origin[1]
+    """
+    if device is None:
+        # numpy
+        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
+    else:
+        # torch
+        arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
+        meshgrid, stack = torch.meshgrid, torch.stack
+        ones = lambda *a: torch.ones(*a, device=device)
+
+    tw, th = [arange(o, o+s, **arange_kw) for s, o in zip((W, H), origin)]
+    grid = meshgrid(tw, th, indexing='xy')
+    if homogeneous:
+        grid = grid + (ones((H, W)),)
+    if unsqueeze is not None:
+        grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
+    if cat_dim is not None:
+        grid = stack(grid, cat_dim)
+    return grid
+
+
+def geotrf(Trf, pts, ncol=None, norm=False):
+    """ Apply a geometric transformation to a list of 3-D points.
+
+    Trf: 3x3 or 4x4 projection matrix (typically a Homography)
+    pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
+
+    ncol: int. number of columns of the result (2 or 3)
+    norm: float. if != 0, the result is projected on the z=norm plane.
+
+    Returns an array of transformed points with shape (..., ncol).
+    """
+    assert Trf.ndim >= 2
+    if isinstance(Trf, np.ndarray):
+        pts = np.asarray(pts)
+    elif isinstance(Trf, torch.Tensor):
+        pts = torch.as_tensor(pts, dtype=Trf.dtype)
+
+    # adapt shape if necessary
+    output_reshape = pts.shape[:-1]
+    ncol = ncol or pts.shape[-1]
+
+    # optimized code
+    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
+            Trf.ndim == 3 and pts.ndim == 4):
+        d = pts.shape[3]
+        if Trf.shape[-1] == d:
+            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
+        elif Trf.shape[-1] == d+1:
+            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
+        else:
+            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
+    else:
+        if Trf.ndim >= 3:
+            n = Trf.ndim-2
+            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
+            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
+
+            if pts.ndim > Trf.ndim:
+                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
+                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
+            elif pts.ndim == 2:
+                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
+                pts = pts[:, None, :]
+
+        if pts.shape[-1]+1 == Trf.shape[-1]:
+            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
+            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
+        elif pts.shape[-1] == Trf.shape[-1]:
+            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
+            pts = pts @ Trf
+        else:
+            pts = Trf @ pts.T
+            if pts.ndim >= 2:
+                pts = pts.swapaxes(-1, -2)
+
+    if norm:
+        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
+        if norm != 1:
+            pts *= norm
+
+    res = pts[..., :ncol].reshape(*output_reshape, ncol)
+    return res
+
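A quick sanity check of geotrf on a batched pointmap (a pure translation; the shapes are chosen to hit the optimized einsum path above):

    import torch
    from dust3r.utils.geometry import geotrf

    T = torch.eye(4)[None]                  # (1, 4, 4) rigid transform
    T[:, :3, 3] = torch.tensor([1., 2., 3.])
    pts = torch.zeros(1, 2, 2, 3)           # (B, H, W, 3) pointmap
    print(geotrf(T, pts)[0, 0, 0])          # tensor([1., 2., 3.])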
+
+def inv(mat):
+    """ Invert a torch or numpy matrix
+    """
+    if isinstance(mat, torch.Tensor):
+        return torch.linalg.inv(mat)
+    if isinstance(mat, np.ndarray):
+        return np.linalg.inv(mat)
+    raise ValueError(f'bad matrix type = {type(mat)}')
+
+
+def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
+    """
+    Args:
+        - depthmap (BxHxW array):
+        - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
+    Returns:
+        pointmap of absolute coordinates (BxHxWx3 array)
+    """
+
+    if len(depth.shape) == 4:
+        B, H, W, n = depth.shape
+    else:
+        B, H, W = depth.shape
+        n = None
+
+    if len(pseudo_focal.shape) == 3:  # [B,H,W]
+        pseudo_focalx = pseudo_focaly = pseudo_focal
+    elif len(pseudo_focal.shape) == 4:  # [B,2,H,W] or [B,1,H,W]
+        pseudo_focalx = pseudo_focal[:, 0]
+        if pseudo_focal.shape[1] == 2:
+            pseudo_focaly = pseudo_focal[:, 1]
+        else:
+            pseudo_focaly = pseudo_focalx
+    else:
+        raise NotImplementedError("Error, unknown input focal shape format.")
+
+    assert pseudo_focalx.shape == depth.shape[:3]
+    assert pseudo_focaly.shape == depth.shape[:3]
+    grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
+
+    # set principal point
+    if pp is None:
+        grid_x = grid_x - (W-1)/2
+        grid_y = grid_y - (H-1)/2
+    else:
+        grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
+        grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
+
+    if n is None:
+        pts3d = torch.empty((B, H, W, 3), device=depth.device)
+        pts3d[..., 0] = depth * grid_x / pseudo_focalx
+        pts3d[..., 1] = depth * grid_y / pseudo_focaly
+        pts3d[..., 2] = depth
+    else:
+        pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
+        pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
+        pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
+        pts3d[..., 2, :] = depth
+    return pts3d
+
+
+def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
+    """
+    Args:
+        - depthmap (HxW array):
+        - camera_intrinsics: a 3x3 matrix
+    Returns:
+        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
+    """
+    camera_intrinsics = np.float32(camera_intrinsics)
+    H, W = depthmap.shape
+
+    # Compute 3D ray associated with each pixel
+    # Strong assumption: there are no skew terms
+    assert camera_intrinsics[0, 1] == 0.0
+    assert camera_intrinsics[1, 0] == 0.0
+    if pseudo_focal is None:
+        fu = camera_intrinsics[0, 0]
+        fv = camera_intrinsics[1, 1]
+    else:
+        assert pseudo_focal.shape == (H, W)
+        fu = fv = pseudo_focal
+    cu = camera_intrinsics[0, 2]
+    cv = camera_intrinsics[1, 2]
+
+    u, v = np.meshgrid(np.arange(W), np.arange(H))
+    z_cam = depthmap
+    x_cam = (u - cu) * z_cam / fu
+    y_cam = (v - cv) * z_cam / fv
+    X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+    # Mask for valid coordinates
+    valid_mask = (depthmap > 0.0)
+    return X_cam, valid_mask
+
+
+def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw):
+    """
+    Args:
+        - depthmap (HxW array):
+        - camera_intrinsics: a 3x3 matrix
+        - camera_pose: a 4x3 or 4x4 cam2world matrix
+    Returns:
+        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
+    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
+
+    # R_cam2world = np.float32(camera_params["R_cam2world"])
+    # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
+    R_cam2world = camera_pose[:3, :3]
+    t_cam2world = camera_pose[:3, 3]
+
+    # Express in absolute (world) coordinates (invalid depths are flagged by valid_mask)
+    X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
+    return X_world, valid_mask
+
+
+def colmap_to_opencv_intrinsics(K):
+    """
+    Modify camera intrinsics to follow a different convention.
+    Coordinates of the center of the top-left pixels are by default:
+    - (0.5, 0.5) in Colmap
+    - (0,0) in OpenCV
+    """
+    K = K.copy()
+    K[0, 2] -= 0.5
+    K[1, 2] -= 0.5
+    return K
+
+
+def opencv_to_colmap_intrinsics(K):
+    """
+    Modify camera intrinsics to follow a different convention.
+    Coordinates of the center of the top-left pixels are by default:
+    - (0.5, 0.5) in Colmap
+    - (0,0) in OpenCV
+    """
+    K = K.copy()
+    K[0, 2] += 0.5
+    K[1, 2] += 0.5
+    return K
+
+
+def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None):
+    """ renorm pointmaps pts1, pts2 with norm_mode
+    """
+    assert pts1.ndim >= 3 and pts1.shape[-1] == 3
+    assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
+    norm_mode, dis_mode = norm_mode.split('_')
+
+    if norm_mode == 'avg':
+        # gather all points together (joint normalization)
+        nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
+        nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
+        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
+
+        # compute distance to origin
+        all_dis = all_pts.norm(dim=-1)
+        if dis_mode == 'dis':
+            pass  # do nothing
+        elif dis_mode == 'log1p':
+            all_dis = torch.log1p(all_dis)
+        elif dis_mode == 'warp-log1p':
+            # actually warp input points before normalizing them
+            log_dis = torch.log1p(all_dis)
+            warp_factor = log_dis / all_dis.clip(min=1e-8)
+            H1, W1 = pts1.shape[1:-1]
+            pts1 = pts1 * warp_factor[:, :W1*H1].view(-1, H1, W1, 1)
+            if pts2 is not None:
+                H2, W2 = pts2.shape[1:-1]
+                pts2 = pts2 * warp_factor[:, W1*H1:].view(-1, H2, W2, 1)
+            all_dis = log_dis  # this is their true distance afterwards
+        else:
+            raise ValueError(f'bad {dis_mode=}')
+
+        norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
+    else:
+        # gather all points together (joint normalization)
+        nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
+        nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
+        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
+
+        # compute distance to origin
+        all_dis = all_pts.norm(dim=-1)
+
+        if norm_mode == 'avg':
+            norm_factor = all_dis.nanmean(dim=1)
+        elif norm_mode == 'median':
+            norm_factor = all_dis.nanmedian(dim=1).values.detach()
+        elif norm_mode == 'sqrt':
+            norm_factor = all_dis.sqrt().nanmean(dim=1)**2
+        else:
+            raise ValueError(f'bad {norm_mode=}')
+
+    norm_factor = norm_factor.clip(min=1e-8)
+    while norm_factor.ndim < pts1.ndim:
+        norm_factor.unsqueeze_(-1)
+
+    res = pts1 / norm_factor
+    if pts2 is not None:
+        res = (res, pts2 / norm_factor)
+    return res
+
+
+@torch.no_grad()
+def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
+    # set invalid points to NaN
+    _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
+    _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None
+    _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1
+
+    # compute median depth overall (ignoring nans)
+    if quantile == 0.5:
+        shift_z = torch.nanmedian(_z, dim=-1).values
+    else:
+        shift_z = torch.nanquantile(_z, quantile, dim=-1)
+    return shift_z  # (B,)
+
+
+@torch.no_grad()
+def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):
+    # set invalid points to NaN
+    _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
+    _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None
+    _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1
+
+    # compute median center
+    _center = torch.nanmedian(_pts, dim=1, keepdim=True).values  # (B,1,3)
+    if z_only:
+        _center[..., :2] = 0  # do not center X and Y
+
+    # compute median norm
+    _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
+    scale = torch.nanmedian(_norm, dim=1).values
+    return _center[:, None, :, :], scale[:, None, None, None]
+
+
+def find_reciprocal_matches(P1, P2):
+    """
+    returns 3 values:
+    1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
+    2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
+    3 - reciprocal_in_P2.sum(): the number of matches
+    """
+    tree1 = KDTree(P1)
+    tree2 = KDTree(P2)
+
+    _, nn1_in_P2 = tree2.query(P1, workers=8)
+    _, nn2_in_P1 = tree1.query(P2, workers=8)
+
+    reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)))
+    reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)))
+    assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
+    return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
+
+
+def get_med_dist_between_poses(poses):
+    from scipy.spatial.distance import pdist
+    return np.median(pdist([to_numpy(p[:3, 3]) for p in poses]))
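A minimal check of the depth-to-pointmap conversion above with invented intrinsics (focal of 100 px, principal point at the image center):

    import numpy as np
    from dust3r.utils.geometry import depthmap_to_camera_coordinates

    K = np.array([[100., 0., 32.],
                  [0., 100., 24.],
                  [0., 0., 1.]])
    depth = np.full((48, 64), 2.0, dtype=np.float32)
    X_cam, valid = depthmap_to_camera_coordinates(depth, K)
    print(X_cam.shape, bool(valid.all()))  # (48, 64, 3) True
    print(X_cam[24, 32])                   # pixel at the principal point -> [0. 0. 2.]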
diff --git a/dust3r/utils/image.py b/dust3r/utils/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a69c4ea139c7528f7abad2ce2e22b1178159d83
--- /dev/null
+++ b/dust3r/utils/image.py
@@ -0,0 +1,148 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utility functions for images (loading/converting...)
+# --------------------------------------------------------
+import os
+import torch
+import numpy as np
+import PIL.Image
+from PIL.ImageOps import exif_transpose
+import torchvision.transforms as tvf
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+import cv2  # noqa
+
+try:
+    from pillow_heif import register_heif_opener  # noqa
+    register_heif_opener()
+    heif_support_enabled = True
+except ImportError:
+    heif_support_enabled = False
+
+ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+
+
+def imread_cv2(path, options=cv2.IMREAD_COLOR):
+    """ Open an image or a depthmap with opencv-python.
+    """
+    if path.endswith(('.exr', 'EXR')):
+        options = cv2.IMREAD_ANYDEPTH
+    img = cv2.imread(path, options)
+    if img is None:
+        raise IOError(f'Could not load image={path} with {options=}')
+    if img.ndim == 3:
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return img
+
+
+def rgb(ftensor, true_shape=None):
+    if isinstance(ftensor, list):
+        return [rgb(x, true_shape=true_shape) for x in ftensor]
+    if isinstance(ftensor, torch.Tensor):
+        ftensor = ftensor.detach().cpu().numpy()  # H,W,3
+    if ftensor.ndim == 3 and ftensor.shape[0] == 3:
+        ftensor = ftensor.transpose(1, 2, 0)
+    elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
+        ftensor = ftensor.transpose(0, 2, 3, 1)
+    if true_shape is not None:
+        H, W = true_shape
+        ftensor = ftensor[:H, :W]
+    if ftensor.dtype == np.uint8:
+        img = np.float32(ftensor) / 255
+    else:
+        img = (ftensor * 0.5) + 0.5
+    return img.clip(min=0, max=1)
+
+
+def _resize_pil_image(img, long_edge_size):
+    S = max(img.size)
+    if S > long_edge_size:
+        interp = PIL.Image.LANCZOS
+    elif S <= long_edge_size:
+        interp = PIL.Image.BICUBIC
+    new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size)
+    return img.resize(new_size, interp)
+
+def resize_images(image_list, size, square_ok=False):
+    """ open and convert all images in a list or folder to proper input format for DUSt3R
+    """
+    imgs = []
+    for image in image_list:
+        img = exif_transpose(image).convert('RGB')
+        W1, H1 = img.size
+        if size == 224:
+            # resize short side to 224 (then crop)
+            img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
+        else:
+            # resize long side to 512
+            img = _resize_pil_image(img, size)
+        W, H = img.size
+        cx, cy = W//2, H//2
+        if size == 224:
+            half = min(cx, cy)
+            img = img.crop((cx-half, cy-half, cx+half, cy+half))
+        else:
+            halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
+            if not (square_ok) and W == H:
+                halfh = 3*halfw/4
+            img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))
+
+        W2, H2 = img.size
+        print(f' - resize image with resolution {W1}x{H1} --> {W2}x{H2}')
+        imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
+            [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
+    assert imgs, 'no images'
+    print(f' (resized {len(imgs)} images)')
+    return imgs
+
+def load_images(folder_or_list, size, square_ok=False):
+    """ open and convert all images in a list or folder to proper input format for DUSt3R
+    """
+    if isinstance(folder_or_list, str):
+        print(f'>> Loading images from {folder_or_list}')
+        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
+
+    elif isinstance(folder_or_list, list):
+        print(f'>> Loading a list of {len(folder_or_list)} images')
+        root, folder_content = '', folder_or_list
+
+    else:
+        raise ValueError(f'bad {folder_or_list=} ({type(folder_or_list)})')
+
+    supported_images_extensions = ['.jpg', '.jpeg', '.png']
+    if heif_support_enabled:
+        supported_images_extensions += ['.heic', '.heif']
+    supported_images_extensions = tuple(supported_images_extensions)
+
+    imgs = []
+    for path in folder_content:
+        if not path.lower().endswith(supported_images_extensions):
+            continue
+        img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert('RGB')
+        W1, H1 = img.size
+        if size == 224:
+            # resize short side to 224 (then crop)
+            img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)))
+        else:
+            # resize long side to 512
+            img = _resize_pil_image(img, size)
+        W, H = img.size
+        cx, cy = W//2, H//2
+        if size == 224:
+            half = min(cx, cy)
+            img = img.crop((cx-half, cy-half, cx+half, cy+half))
+        else:
+            halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8
+            if not (square_ok) and W == H:
+                halfh = 3*halfw/4
+            img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh))
+
+        W2, H2 = img.size
+        print(f' - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}')
+        imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
+            [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
+
+    assert imgs, 'no images found at '+root
+    print(f' (Found {len(imgs)} images)')
+    return imgs
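+# Minimal usage sketch (hypothetical path):
+#   views = load_images('/path/to/images', size=512)
+# Each entry is a dict with keys 'img', 'true_shape', 'idx' and 'instance',
+# ready to be turned into pairs and fed to the DUSt3R model.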
diff --git a/dust3r/utils/misc.py b/dust3r/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab9fd06a063c3eafbfafddc011064ebb8a3232a8
--- /dev/null
+++ b/dust3r/utils/misc.py
@@ -0,0 +1,121 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utility functions for DUSt3R
+# --------------------------------------------------------
+import torch
+
+
+def fill_default_args(kwargs, func):
+    import inspect  # a bit hacky but it works reliably
+    signature = inspect.signature(func)
+
+    for k, v in signature.parameters.items():
+        if v.default is inspect.Parameter.empty:
+            continue
+        kwargs.setdefault(k, v.default)
+
+    return kwargs
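+# Example: for `def f(a, b=2, c=3)`, fill_default_args({'c': 5}, f) returns
+# {'c': 5, 'b': 2}: explicit values are kept, missing defaults are filled in,
+# and parameters without a default (here `a`) are left untouched.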
+
+
+def freeze_all_params(modules):
+    for module in modules:
+        try:
+            for n, param in module.named_parameters():
+                param.requires_grad = False
+        except AttributeError:
+            # module is directly a parameter
+            module.requires_grad = False
+
+
+def is_symmetrized(gt1, gt2):
+    x = gt1['instance']
+    y = gt2['instance']
+    if len(x) == len(y) and len(x) == 1:
+        return False  # special case of batchsize 1
+    ok = True
+    for i in range(0, len(x), 2):
+        ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i])
+    return ok
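+# Note: this assumes the batch was built from symmetrized pairs, i.e. instances come in
+# consecutive (i, j) / (j, i) couples; it returns True only if every couple is mirrored.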
+
+
+def flip(tensor):
+    """ flip so that tensor[0::2] <=> tensor[1::2] """
+    return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)
+
+
+def interleave(tensor1, tensor2):
+    res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
+    res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
+    return res1, res2
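+# Example: for a tensor [a, b, c, d] along dim 0, flip() returns [b, a, d, c]
+# (consecutive pairs swapped); interleave([a, b], [c, d]) returns ([a, c, b, d], [c, a, d, b]).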
+
+
+def transpose_to_landscape(head, activate=True):
+    """ Predict in the correct aspect-ratio,
+        then transpose the result in landscape 
+        and stack everything back together.
+    """
+    def wrapper_no(decout, true_shape):
+        B = len(true_shape)
+        assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'
+        H, W = true_shape[0].cpu().tolist()
+        res = head(decout, (H, W))
+        return res
+
+    def wrapper_yes(decout, true_shape):
+        B = len(true_shape)
+        # by definition, the batch is in landscape mode so W >= H
+        H, W = int(true_shape.min()), int(true_shape.max())
+
+        height, width = true_shape.T
+        is_landscape = (width >= height)
+        is_portrait = ~is_landscape
+
+        # true_shape = true_shape.cpu()
+        if is_landscape.all():
+            return head(decout, (H, W))
+        if is_portrait.all():
+            return transposed(head(decout, (W, H)))
+
+        # batch is a mix of both portrait & landscape
+        def selout(ar): return [d[ar] for d in decout]
+        l_result = head(selout(is_landscape), (H, W))
+        p_result = transposed(head(selout(is_portrait),  (W, H)))
+
+        # allocate full result
+        result = {}
+        for k in l_result | p_result:
+            x = l_result[k].new(B, *l_result[k].shape[1:])
+            x[is_landscape] = l_result[k]
+            x[is_portrait] = p_result[k]
+            result[k] = x
+
+        return result
+
+    return wrapper_yes if activate else wrapper_no
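+# Note: `wrapper_yes` assumes every view in the batch shares the same two side lengths;
+# portrait views are decoded as (W, H) and transposed back, so the returned maps are
+# always landscape-shaped regardless of the input orientation.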
+
+
+def transposed(dic):
+    return {k: v.swapaxes(1, 2) for k, v in dic.items()}
+
+
+def invalid_to_nans(arr, valid_mask, ndim=999):
+    if valid_mask is not None:
+        arr = arr.clone()
+        arr[~valid_mask] = float('nan')
+    if arr.ndim > ndim:
+        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+    return arr
+
+
+def invalid_to_zeros(arr, valid_mask, ndim=999):
+    if valid_mask is not None:
+        arr = arr.clone()
+        arr[~valid_mask] = 0
+        nnz = valid_mask.view(len(valid_mask), -1).sum(1)
+    else:
+        nnz = arr.numel() // len(arr) if len(arr) else 0  # number of points per image
+    if arr.ndim > ndim:
+        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+    return arr, nnz
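+# Example: for `arr` of shape (B, H, W, 3) and a boolean `valid_mask` of shape (B, H, W),
+# invalid_to_zeros(arr, valid_mask, ndim=3) zeroes the masked-out points, flattens the
+# result to (B, H*W, 3) and also returns the number of valid points per image.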
diff --git a/dust3r/utils/path_to_croco.py b/dust3r/utils/path_to_croco.py
new file mode 100644
index 0000000000000000000000000000000000000000..39226ce6bc0e1993ba98a22096de32cb6fa916b4
--- /dev/null
+++ b/dust3r/utils/path_to_croco.py
@@ -0,0 +1,19 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# CroCo submodule import
+# --------------------------------------------------------
+
+import sys
+import os.path as path
+HERE_PATH = path.normpath(path.dirname(__file__))
+CROCO_REPO_PATH = path.normpath(path.join(HERE_PATH, '../../croco'))
+CROCO_MODELS_PATH = path.join(CROCO_REPO_PATH, 'models')
+# check the presence of the models directory in the repo to be sure the submodule is cloned
+if path.isdir(CROCO_MODELS_PATH):
+    # workaround for sibling import
+    sys.path.insert(0, CROCO_REPO_PATH)
+else:
+    raise ImportError(f"croco is not initialized, could not find: {CROCO_MODELS_PATH}.\n "
+                      "Did you forget to run 'git submodule update --init --recursive' ?")
diff --git a/dust3r/viz.py b/dust3r/viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..a21f399accf6710816cc4a858d60849ccaad31e1
--- /dev/null
+++ b/dust3r/viz.py
@@ -0,0 +1,320 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Visualization utilities using trimesh
+# --------------------------------------------------------
+import PIL.Image
+import numpy as np
+from scipy.spatial.transform import Rotation
+import torch
+
+from dust3r.utils.geometry import geotrf, get_med_dist_between_poses
+from dust3r.utils.device import to_numpy
+from dust3r.utils.image import rgb
+
+try:
+    import trimesh
+except ImportError:
+    print('/!\\ module trimesh is not installed, cannot visualize results /!\\')
+
+
+def cat_3d(vecs):
+    if isinstance(vecs, (np.ndarray, torch.Tensor)):
+        vecs = [vecs]
+    return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)])
+
+
+def show_raw_pointcloud(pts3d, colors, point_size=2):
+    scene = trimesh.Scene()
+
+    pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))
+    scene.add_geometry(pct)
+
+    scene.show(line_settings={'point_size': point_size})
+
+
+def pts3d_to_trimesh(img, pts3d, valid=None):
+    H, W, THREE = img.shape
+    assert THREE == 3
+    assert img.shape == pts3d.shape
+
+    vertices = pts3d.reshape(-1, 3)
+
+    # make squares: each pixel == 2 triangles
+    idx = np.arange(len(vertices)).reshape(H, W)
+    idx1 = idx[:-1, :-1].ravel()  # top-left corner
+    idx2 = idx[:-1, +1:].ravel()  # top-right corner
+    idx3 = idx[+1:, :-1].ravel()  # bottom-left corner
+    idx4 = idx[+1:, +1:].ravel()  # bottom-right corner
+    faces = np.concatenate((
+        np.c_[idx1, idx2, idx3],
+        np.c_[idx3, idx2, idx1],  # same triangle, but backward (cheap solution to cancel face culling)
+        np.c_[idx2, idx3, idx4],
+        np.c_[idx4, idx3, idx2],  # same triangle, but backward (cheap solution to cancel face culling)
+    ), axis=0)
+
+    # prepare triangle colors
+    face_colors = np.concatenate((
+        img[:-1, :-1].reshape(-1, 3),
+        img[:-1, :-1].reshape(-1, 3),
+        img[+1:, +1:].reshape(-1, 3),
+        img[+1:, +1:].reshape(-1, 3)
+    ), axis=0)
+
+    # remove invalid faces
+    if valid is not None:
+        assert valid.shape == (H, W)
+        valid_idxs = valid.ravel()
+        valid_faces = valid_idxs[faces].all(axis=-1)
+        faces = faces[valid_faces]
+        face_colors = face_colors[valid_faces]
+
+    assert len(faces) == len(face_colors)
+    return dict(vertices=vertices, face_colors=face_colors, faces=faces)
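+# Sketch (assuming `img`, `pts3d` and `valid` are HxWx3 / HxW arrays from one view):
+#   mesh = trimesh.Trimesh(**pts3d_to_trimesh(img, pts3d, valid))
+# Several per-view meshes can also be merged with cat_meshes() below before display.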
+
+
+def cat_meshes(meshes):
+    vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])
+    n_vertices = np.cumsum([0]+[len(v) for v in vertices])
+    for i in range(len(faces)):
+        faces[i][:] += n_vertices[i]
+
+    vertices = np.concatenate(vertices)
+    colors = np.concatenate(colors)
+    faces = np.concatenate(faces)
+    return dict(vertices=vertices, face_colors=colors, faces=faces)
+
+
+def show_duster_pairs(view1, view2, pred1, pred2):
+    import matplotlib.pyplot as pl
+    pl.ion()
+
+    for e in range(len(view1['instance'])):
+        i = view1['idx'][e]
+        j = view2['idx'][e]
+        img1 = rgb(view1['img'][e])
+        img2 = rgb(view2['img'][e])
+        conf1 = pred1['conf'][e].squeeze()
+        conf2 = pred2['conf'][e].squeeze()
+        score = conf1.mean()*conf2.mean()
+        print(f">> Showing pair #{e} {i}-{j} {score=:g}")
+        pl.clf()
+        pl.subplot(221).imshow(img1)
+        pl.subplot(223).imshow(img2)
+        pl.subplot(222).imshow(conf1, vmin=1, vmax=30)
+        pl.subplot(224).imshow(conf2, vmin=1, vmax=30)
+        pts1 = pred1['pts3d'][e]
+        pts2 = pred2['pts3d_in_other_view'][e]
+        pl.subplots_adjust(0, 0, 1, 1, 0, 0)
+        if input('show pointcloud? (y/n) ') == 'y':
+            show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5)
+
+
+def auto_cam_size(im_poses):
+    return 0.1 * get_med_dist_between_poses(im_poses)
+
+
+class SceneViz:
+    def __init__(self):
+        self.scene = trimesh.Scene()
+
+    def add_pointcloud(self, pts3d, color, mask=None):
+        pts3d = to_numpy(pts3d)
+        mask = to_numpy(mask)
+        if mask is None:
+            mask = [slice(None)] * len(pts3d)
+        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
+        pct = trimesh.PointCloud(pts.reshape(-1, 3))
+
+        if isinstance(color, (list, np.ndarray, torch.Tensor)):
+            color = to_numpy(color)
+            col = np.concatenate([p[m] for p, m in zip(color, mask)])
+            assert col.shape == pts.shape
+            pct.visual.vertex_colors = uint8(col.reshape(-1, 3))
+        else:
+            assert len(color) == 3
+            pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape)
+
+        self.scene.add_geometry(pct)
+        return self
+
+    def add_camera(self, pose_c2w, focal=None, color=(0, 0, 0), image=None, imsize=None, cam_size=0.03):
+        pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image))
+        add_scene_cam(self.scene, pose_c2w, color, image, focal, screen_width=cam_size)
+        return self
+
+    def add_cameras(self, poses, focals=None, images=None, imsizes=None, colors=None, **kw):
+        def get(arr, idx): return None if arr is None else arr[idx]
+        for i, pose_c2w in enumerate(poses):
+            self.add_camera(pose_c2w, get(focals, i), image=get(images, i),
+                            color=get(colors, i), imsize=get(imsizes, i), **kw)
+        return self
+
+    def show(self, point_size=2):
+        self.scene.show(line_settings={'point_size': point_size})
+
+
+def show_raw_pointcloud_with_cams(imgs, pts3d, mask, focals, cams2world,
+                                  point_size=2, cam_size=0.05, cam_color=None):
+    """ Visualization of a pointcloud with cameras
+        imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...]
+        pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...]
+        focals = (N,) or N-size list of [focal, ...]
+        cams2world = (N,4,4) or N-size list of [(4,4), ...]
+    """
+    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
+    pts3d = to_numpy(pts3d)
+    imgs = to_numpy(imgs)
+    focals = to_numpy(focals)
+    cams2world = to_numpy(cams2world)
+
+    scene = trimesh.Scene()
+
+    # full pointcloud
+    pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
+    col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
+    pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
+    scene.add_geometry(pct)
+
+    # add each camera
+    for i, pose_c2w in enumerate(cams2world):
+        if isinstance(cam_color, list):
+            camera_edge_color = cam_color[i]
+        else:
+            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
+        add_scene_cam(scene, pose_c2w, camera_edge_color,
+                      imgs[i] if i < len(imgs) else None, focals[i], screen_width=cam_size)
+
+    scene.show(line_settings={'point_size': point_size})
+
+
+def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03):
+
+    if image is not None:
+        H, W, THREE = image.shape
+        assert THREE == 3
+        if image.dtype != np.uint8:
+            image = np.uint8(255*image)
+    elif imsize is not None:
+        W, H = imsize
+    elif focal is not None:
+        H = W = focal / 1.1
+    else:
+        H = W = 1
+
+    if focal is None:
+        focal = min(H, W) * 1.1  # default value
+    elif isinstance(focal, np.ndarray):
+        focal = focal[0]
+
+    # create fake camera
+    height = focal * screen_width / H
+    width = screen_width * 0.5**0.5
+    rot45 = np.eye(4)
+    rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
+    rot45[2, 3] = -height  # set the tip of the cone = optical center
+    aspect_ratio = np.eye(4)
+    aspect_ratio[0, 0] = W/H
+    transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
+    cam = trimesh.creation.cone(width, height, sections=4)  # , transform=transform)
+
+    # this is the image
+    if image is not None:
+        vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
+        faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
+        img = trimesh.Trimesh(vertices=vertices, faces=faces)
+        uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
+        img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
+        scene.add_geometry(img)
+
+    # this is the camera mesh
+    rot2 = np.eye(4)
+    rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
+    vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]
+    vertices = geotrf(transform, vertices)
+    faces = []
+    for face in cam.faces:
+        if 0 in face:
+            continue
+        a, b, c = face
+        a2, b2, c2 = face + len(cam.vertices)
+        a3, b3, c3 = face + 2*len(cam.vertices)
+
+        # add 3 pseudo-edges
+        faces.append((a, b, b2))
+        faces.append((a, a2, c))
+        faces.append((c2, b, c))
+
+        faces.append((a, b, b3))
+        faces.append((a, a3, c))
+        faces.append((c3, b, c))
+
+    # no culling
+    faces += [(c, b, a) for a, b, c in faces]
+
+    cam = trimesh.Trimesh(vertices=vertices, faces=faces)
+    cam.visual.face_colors[:, :3] = edge_color
+    scene.add_geometry(cam)
+
+
+def cat(a, b):
+    return np.concatenate((a.reshape(-1, 3), b.reshape(-1, 3)))
+
+
+OPENGL = np.array([[1, 0, 0, 0],
+                   [0, -1, 0, 0],
+                   [0, 0, -1, 0],
+                   [0, 0, 0, 1]])
+
+
+CAM_COLORS = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204),
+              (128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)]
+
+
+def uint8(colors):
+    if not isinstance(colors, np.ndarray):
+        colors = np.array(colors)
+    if np.issubdtype(colors.dtype, np.floating):
+        colors *= 255
+    assert 0 <= colors.min() and colors.max() < 256
+    return np.uint8(colors)
+
+
+def segment_sky(image):
+    import cv2
+    from scipy import ndimage
+
+    # Convert to HSV
+    image = to_numpy(image)
+    if np.issubdtype(image.dtype, np.floating):
+        image = np.uint8(255*image.clip(min=0, max=1))
+    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+    # Define range for blue color and create mask
+    lower_blue = np.array([0, 0, 100])
+    upper_blue = np.array([30, 255, 255])
+    mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)
+
+    # add luminous gray
+    mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)
+    mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)
+    mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)
+
+    # Morphological operations
+    kernel = np.ones((5, 5), np.uint8)
+    mask2 = ndimage.binary_opening(mask, structure=kernel)
+
+    # keep only largest CC
+    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8)
+    cc_sizes = stats[1:, cv2.CC_STAT_AREA]
+    order = cc_sizes.argsort()[::-1]  # bigger first
+    i = 0
+    selection = []
+    while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:
+        selection.append(1 + order[i])
+        i += 1
+    mask3 = np.in1d(labels, selection).reshape(labels.shape)
+
+    # Apply mask
+    return torch.from_numpy(mask3)
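+# Sketch: given an (H, W, 3) image (uint8 or float in [0, 1]), segment_sky() returns a
+# boolean torch mask of shape (H, W) that is True on the largest bright, low-saturation
+# connected components, a cheap heuristic for sky regions.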