Spaces:
Runtime error
Runtime error
import os.path as osp | |
import numpy as np | |
import itertools | |
import os | |
import sys | |
sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) | |
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset | |
from dust3r.utils.image import imread_cv2 | |
class MegaDepth_Multi(BaseMultiViewDataset): | |
def __init__(self, *args, ROOT, **kwargs): | |
self.ROOT = ROOT | |
super().__init__(*args, **kwargs) | |
self._load_data(self.split) | |
self.is_metric = False | |
if self.split is None: | |
pass | |
elif self.split == "train": | |
self.select_scene(("0015", "0022"), opposite=True) | |
elif self.split == "val": | |
self.select_scene(("0015", "0022")) | |
else: | |
raise ValueError(f"bad {self.split=}") | |
def _load_data(self, split): | |
with np.load( | |
osp.join(self.ROOT, "megadepth_sets_64.npz"), allow_pickle=True | |
) as data: | |
self.all_scenes = data["scenes"] | |
self.all_images = data["images"] | |
self.sets = data["sets"] | |
def __len__(self): | |
return len(self.sets) | |
def get_image_num(self): | |
return len(self.all_images) | |
def get_stats(self): | |
return f"{len(self)} groups from {len(self.all_scenes)} scenes" | |
def select_scene(self, scene, *instances, opposite=False): | |
scenes = (scene,) if isinstance(scene, str) else tuple(scene) | |
scene_id = [s.startswith(scenes) for s in self.all_scenes] | |
assert any(scene_id), "no scene found" | |
valid = np.in1d(self.sets[:, 0], np.nonzero(scene_id)[0]) | |
if instances: | |
raise NotImplementedError("selecting instances not implemented") | |
if opposite: | |
valid = ~valid | |
assert valid.any() | |
self.sets = self.sets[valid] | |
def _get_views(self, idx, resolution, rng, num_views): | |
scene_id = self.sets[idx][0] | |
image_idxs = self.sets[idx][1:65] | |
replace = False if not self.allow_repeat else True | |
image_idxs = rng.choice(image_idxs, num_views, replace=replace) | |
scene, subscene = self.all_scenes[scene_id].split() | |
seq_path = osp.join(self.ROOT, scene, subscene) | |
views = [] | |
for im_id in image_idxs: | |
img = self.all_images[im_id] | |
try: | |
image = imread_cv2(osp.join(seq_path, img + ".jpg")) | |
depthmap = imread_cv2(osp.join(seq_path, img + ".exr")) | |
camera_params = np.load(osp.join(seq_path, img + ".npz")) | |
except Exception as e: | |
raise OSError(f"cannot load {img}, got exception {e}") | |
intrinsics = np.float32(camera_params["intrinsics"]) | |
camera_pose = np.float32(camera_params["cam2world"]) | |
image, depthmap, intrinsics = self._crop_resize_if_necessary( | |
image, depthmap, intrinsics, resolution, rng, info=(seq_path, img) | |
) | |
views.append( | |
dict( | |
img=image, | |
depthmap=depthmap, | |
camera_pose=camera_pose, # cam2world | |
camera_intrinsics=intrinsics, | |
dataset="MegaDepth", | |
label=osp.relpath(seq_path, self.ROOT), | |
is_metric=self.is_metric, | |
instance=img, | |
is_video=False, | |
quantile=np.array(0.96, dtype=np.float32), | |
img_mask=True, | |
ray_mask=False, | |
camera_only=False, | |
depth_only=False, | |
single_view=False, | |
reset=False, | |
) | |
) | |
assert len(views) == num_views | |
return views | |