liguang0115's picture
Add initial project structure with core files, configurations, and sample images
2df809d
import os.path as osp
import numpy as np
import itertools
import os
import sys
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2
class MegaDepth_Multi(BaseMultiViewDataset):
    """Multi-view dataset built from pre-computed MegaDepth image groups.

    The groups are read from ``<ROOT>/megadepth_sets_64.npz``, which provides
    three arrays:

    * ``scenes`` -- strings of the form ``"<scene> <subscene>"`` (they are
      split on whitespace to build the sequence directory),
    * ``images`` -- image base names (without extension),
    * ``sets``   -- one row per group; column 0 is an index into ``scenes``
      and columns 1..64 are indices into ``images``.

    ``self.split`` is expected to be set by the base-class constructor
    (NOTE(review): defined outside this file -- confirm).
    """

    def __init__(self, *args, ROOT, **kwargs):
        """Load the group index and restrict it according to the split.

        Args:
            *args: Forwarded to ``BaseMultiViewDataset.__init__``.
            ROOT: Directory containing ``megadepth_sets_64.npz`` and the
                per-scene data folders.
            **kwargs: Forwarded to ``BaseMultiViewDataset.__init__``.

        Raises:
            ValueError: If ``self.split`` is not ``None``, ``"train"`` or
                ``"val"``.
        """
        # ROOT is assigned before the base constructor runs, as in the
        # original code (the base class may rely on it).
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self._load_data(self.split)
        self.is_metric = False

        # Scenes starting with "0015" / "0022" are held out for validation;
        # training uses every other scene. ``None`` keeps all groups.
        held_out_scenes = ("0015", "0022")
        if self.split == "train":
            self.select_scene(held_out_scenes, opposite=True)
        elif self.split == "val":
            self.select_scene(held_out_scenes)
        elif self.split is not None:
            raise ValueError(f"bad {self.split=}")

    def _load_data(self, split):
        """Read the scene / image / group arrays from the index file.

        Args:
            split: Unused here; kept for interface compatibility.
        """
        index_path = osp.join(self.ROOT, "megadepth_sets_64.npz")
        # SECURITY NOTE: allow_pickle=True lets NumPy unpickle object arrays,
        # which can execute arbitrary code. Only load index files from a
        # trusted source.
        with np.load(index_path, allow_pickle=True) as data:
            self.all_scenes = data["scenes"]
            self.all_images = data["images"]
            self.sets = data["sets"]

    def __len__(self):
        """Return the number of currently selected groups."""
        return len(self.sets)

    def get_image_num(self):
        """Return the total number of image names in the index."""
        return len(self.all_images)

    def get_stats(self):
        """Return a short human-readable summary of the dataset size."""
        return f"{len(self)} groups from {len(self.all_scenes)} scenes"

    def select_scene(self, scene, *instances, opposite=False):
        """Keep only the groups whose scene name starts with a given prefix.

        Filters ``self.sets`` in place.

        Args:
            scene: A prefix string, or an iterable of prefix strings.
            *instances: Not supported; must be empty.
            opposite: If true, keep the groups that do NOT match instead.

        Raises:
            AssertionError: If no scene matches, or if the resulting
                selection is empty.
            NotImplementedError: If any ``instances`` are given.
        """
        prefixes = (scene,) if isinstance(scene, str) else tuple(scene)
        # str.startswith accepts a tuple of prefixes.
        scene_matches = [s.startswith(prefixes) for s in self.all_scenes]
        assert any(scene_matches), "no scene found"

        matching_scene_ids = np.nonzero(scene_matches)[0]
        # np.isin replaces np.in1d, which is deprecated since NumPy 2.0;
        # for a 1-D first argument the result is identical.
        valid = np.isin(self.sets[:, 0], matching_scene_ids)

        if instances:
            raise NotImplementedError("selecting instances not implemented")
        if opposite:
            valid = ~valid
        assert valid.any(), "scene selection left no groups"
        self.sets = self.sets[valid]

    def _get_views(self, idx, resolution, rng, num_views):
        """Sample ``num_views`` views from group ``idx``.

        Args:
            idx: Row index into ``self.sets``.
            resolution: Target resolution, forwarded to
                ``self._crop_resize_if_necessary``.
            rng: Random generator providing ``choice`` (and forwarded to the
                crop/resize helper).
            num_views: Number of views to draw from the group.

        Returns:
            A list of ``num_views`` dicts, one per sampled view.

        Raises:
            OSError: If the image, depth map or camera file of a view cannot
                be loaded (the original exception is chained as the cause).
        """
        group = self.sets[idx]
        scene_id = group[0]
        candidate_idxs = group[1:65]  # the (up to) 64 image indices of the group

        # Sample with replacement only when the dataset allows repeated views
        # (``allow_repeat`` is presumably set by the base class -- confirm).
        replace = bool(self.allow_repeat)
        image_idxs = rng.choice(candidate_idxs, num_views, replace=replace)

        scene, subscene = self.all_scenes[scene_id].split()
        seq_path = osp.join(self.ROOT, scene, subscene)

        views = []
        for im_id in image_idxs:
            img = self.all_images[im_id]
            try:
                image = imread_cv2(osp.join(seq_path, img + ".jpg"))
                depthmap = imread_cv2(osp.join(seq_path, img + ".exr"))
                camera_params = np.load(osp.join(seq_path, img + ".npz"))
            except Exception as e:
                # Chain the cause so the original traceback is preserved.
                raise OSError(f"cannot load {img}, got exception {e}") from e

            intrinsics = np.float32(camera_params["intrinsics"])
            camera_pose = np.float32(camera_params["cam2world"])

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=(seq_path, img)
            )

            # The constant flags below are consumed downstream; their
            # semantics are defined outside this file.
            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="MegaDepth",
                    label=osp.relpath(seq_path, self.ROOT),
                    is_metric=self.is_metric,
                    instance=img,
                    is_video=False,
                    quantile=np.array(0.96, dtype=np.float32),
                    img_mask=True,
                    ray_mask=False,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )

        assert len(views) == num_views
        return views