import os.path as osp import numpy as np import cv2 import numpy as np import itertools import os import sys import pickle import h5py from tqdm import tqdm sys.path.append(osp.join(osp.dirname(__file__), "..", "..")) from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset from dust3r.utils.image import imread_cv2 class MapFree_Multi(BaseMultiViewDataset): def __init__(self, ROOT, *args, **kwargs): self.ROOT = ROOT self.video = True self.is_metric = True self.max_interval = 30 super().__init__(*args, **kwargs) self._load_data() def imgid2path(self, img_id, scene): first_seq_id, first_frame_id = img_id return os.path.join( self.ROOT, scene, f"dense{first_seq_id}", "rgb", f"frame_{first_frame_id:05d}.jpg", ) def path2imgid(self, subscene, filename): first_seq_id = int(subscene[5:]) first_frame_id = int(filename[6:-4]) return [first_seq_id, first_frame_id] def _load_data(self): cache_file = f"{self.ROOT}/cached_metadata_50_col_only.h5" if os.path.exists(cache_file): print(f"Loading cached metadata from {cache_file}") with h5py.File(cache_file, "r") as hf: self.scenes = list(map(lambda x: x.decode("utf-8"), hf["scenes"][:])) self.sceneids = hf["sceneids"][:] self.scope = hf["scope"][:] self.video_flags = hf["video_flags"][:] self.groups = hf["groups"][:] self.id_ranges = hf["id_ranges"][:] self.images = hf["images"][:] else: scene_dirs = sorted( [ d for d in os.listdir(self.ROOT) if os.path.isdir(os.path.join(self.ROOT, d)) ] ) scenes = [] sceneids = [] groups = [] scope = [] images = [] id_ranges = [] is_video = [] start = 0 j = 0 offset = 0 for scene in tqdm(scene_dirs): scenes.append(scene) # video sequences subscenes = sorted( [ d for d in os.listdir(os.path.join(self.ROOT, scene)) if d.startswith("dense") ] ) id_range_subscenes = [] for subscene in subscenes: rgb_paths = sorted( [ d for d in os.listdir( os.path.join(self.ROOT, scene, subscene, "rgb") ) if d.endswith(".jpg") ] ) assert ( len(rgb_paths) > 0 ), f"{os.path.join(self.ROOT, scene, subscene)} is empty." num_imgs = len(rgb_paths) images.extend( [self.path2imgid(subscene, rgb_path) for rgb_path in rgb_paths] ) id_range_subscenes.append((offset, offset + num_imgs)) offset += num_imgs # image collections metadata = pickle.load( open(os.path.join(self.ROOT, scene, "metadata.pkl"), "rb") ) ref_imgs = list(metadata.keys()) img_groups = [] for ref_img in ref_imgs: other_imgs = metadata[ref_img] if len(other_imgs) + 1 < self.num_views: continue group = [(*other_img[0], other_img[1]) for other_img in other_imgs] group.insert(0, (*ref_img, 1)) img_groups.append(np.array(group)) id_ranges.append(id_range_subscenes[ref_img[0]]) scope.append(start) start = start + len(group) num_groups = len(img_groups) sceneids.extend([j] * num_groups) groups.extend(img_groups) is_video.extend([False] * num_groups) j += 1 self.scenes = np.array(scenes) self.sceneids = np.array(sceneids) self.scope = np.array(scope) self.video_flags = np.array(is_video) self.groups = np.concatenate(groups, 0) self.id_ranges = np.array(id_ranges) self.images = np.array(images) data = dict( scenes=self.scenes, sceneids=self.sceneids, scope=self.scope, video_flags=self.video_flags, groups=self.groups, id_ranges=self.id_ranges, images=self.images, ) with h5py.File(cache_file, "w") as h5f: h5f.create_dataset( "scenes", data=data["scenes"].astype(object), dtype=h5py.string_dtype(encoding="utf-8"), compression="lzf", chunks=True, ) h5f.create_dataset( "sceneids", data=data["sceneids"], compression="lzf", chunks=True ) h5f.create_dataset( "scope", data=data["scope"], compression="lzf", chunks=True ) h5f.create_dataset( "video_flags", data=data["video_flags"], compression="lzf", chunks=True, ) h5f.create_dataset( "groups", data=data["groups"], compression="lzf", chunks=True ) h5f.create_dataset( "id_ranges", data=data["id_ranges"], compression="lzf", chunks=True ) h5f.create_dataset( "images", data=data["images"], compression="lzf", chunks=True ) def __len__(self): return len(self.scope) def get_image_num(self): return len(self.images) def get_stats(self): return f"{len(self)} groups of views" def _get_views(self, idx, resolution, rng, num_views): scene = self.scenes[self.sceneids[idx]] if rng.random() < 0.6: ids = np.arange(self.id_ranges[idx][0], self.id_ranges[idx][1]) cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3) start_ids = ids[: len(ids) - cut_off + 1] start_id = rng.choice(start_ids) pos, ordered_video = self.get_seq_from_start_id( num_views, start_id, ids.tolist(), rng, max_interval=self.max_interval, video_prob=0.8, fix_interval_prob=0.5, block_shuffle=16, ) ids = np.array(ids)[pos] image_idxs = self.images[ids] else: ordered_video = False seq_start_index = self.scope[idx] seq_end_index = self.scope[idx + 1] if idx < len(self.scope) - 1 else None image_idxs = ( self.groups[seq_start_index:seq_end_index] if seq_end_index is not None else self.groups[seq_start_index:] ) image_idxs, overlap_scores = image_idxs[:, :2], image_idxs[:, 2] replace = ( True if self.allow_repeat or len(overlap_scores[overlap_scores > 0]) < num_views else False ) image_idxs = rng.choice( image_idxs, num_views, replace=replace, p=overlap_scores / np.sum(overlap_scores), ) image_idxs = image_idxs.astype(np.int64) views = [] for v, view_idx in enumerate(image_idxs): img_path = self.imgid2path(view_idx, scene) depth_path = img_path.replace("rgb", "depth").replace(".jpg", ".npy") cam_path = img_path.replace("rgb", "cam").replace(".jpg", ".npz") sky_mask_path = img_path.replace("rgb", "sky_mask") image = imread_cv2(img_path) depthmap = np.load(depth_path) camera_params = np.load(cam_path) sky_mask = cv2.imread(sky_mask_path, cv2.IMREAD_UNCHANGED) >= 127 intrinsics = camera_params["intrinsic"].astype(np.float32) camera_pose = camera_params["pose"].astype(np.float32) depthmap[sky_mask] = -1.0 depthmap[depthmap > 400.0] = 0.0 depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0) threshold = ( np.percentile(depthmap[depthmap > 0], 98) if depthmap[depthmap > 0].size > 0 else 0 ) depthmap[depthmap > threshold] = 0.0 image, depthmap, intrinsics = self._crop_resize_if_necessary( image, depthmap, intrinsics, resolution, rng, info=(img_path) ) # generate img mask and raymap mask img_mask, ray_mask = self.get_img_and_ray_masks( self.is_metric, v, rng, p=[0.75, 0.2, 0.05] ) views.append( dict( img=image, depthmap=depthmap, camera_pose=camera_pose, # cam2world camera_intrinsics=intrinsics, dataset="MapFree", label=img_path, is_metric=self.is_metric, instance=img_path, is_video=ordered_video, quantile=np.array(0.96, dtype=np.float32), img_mask=img_mask, ray_mask=ray_mask, camera_only=False, depth_only=False, single_view=False, reset=False, ) ) assert len(views) == num_views return views