import os.path as osp
import numpy as np
import cv2
import itertools
import os
import sys
import pickle
import h5py
from tqdm import tqdm

sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2
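# Multi-view dataset loader for the preprocessed MapFree scenes. The expected
# on-disk layout, inferred from the path handling below, is:
#
#   <ROOT>/<scene>/dense<k>/rgb/frame_XXXXX.jpg       RGB frames of video sequence k
#   <ROOT>/<scene>/dense<k>/depth/frame_XXXXX.npy     metric depth maps
#   <ROOT>/<scene>/dense<k>/cam/frame_XXXXX.npz       "intrinsic" and "pose" (cam2world)
#   <ROOT>/<scene>/dense<k>/sky_mask/frame_XXXXX.jpg  sky masks (values >= 127 mean sky)
#   <ROOT>/<scene>/metadata.pkl                       reference image -> overlapping
#                                                     images with overlap scores
#
# On first use the scanned metadata is cached to <ROOT>/cached_metadata_50_col_only.h5
# so subsequent runs skip the directory walk.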
class MapFree_Multi(BaseMultiViewDataset):
    def __init__(self, ROOT, *args, **kwargs):
        self.ROOT = ROOT
        self.video = True
        self.is_metric = True
        self.max_interval = 30
        super().__init__(*args, **kwargs)
        self._load_data()

    def imgid2path(self, img_id, scene):
        first_seq_id, first_frame_id = img_id
        return os.path.join(
            self.ROOT,
            scene,
            f"dense{first_seq_id}",
            "rgb",
            f"frame_{first_frame_id:05d}.jpg",
        )

    def path2imgid(self, subscene, filename):
        first_seq_id = int(subscene[5:])
        first_frame_id = int(filename[6:-4])
        return [first_seq_id, first_frame_id]
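    # _load_data builds flat index arrays over all scenes (and caches them in
    # the h5 file above):
    #   images      - (N, 2) rows of [seq_id, frame_id], one per RGB frame on disk
    #   groups      - (M, 3) rows of [seq_id, frame_id, overlap_score]; the groups
    #                 of all reference images are concatenated into one array
    #   scope       - start offset of each group inside `groups` (a group ends
    #                 where the next one begins)
    #   id_ranges   - per-group [start, end) range into `images`, covering the
    #                 reference image's dense subscene (used for video sampling)
    #   sceneids    - per-group index into `scenes`
    #   video_flags - per-group flag; all collection groups are stored as False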
    def _load_data(self):
        cache_file = f"{self.ROOT}/cached_metadata_50_col_only.h5"
        if os.path.exists(cache_file):
            print(f"Loading cached metadata from {cache_file}")
            with h5py.File(cache_file, "r") as hf:
                self.scenes = list(map(lambda x: x.decode("utf-8"), hf["scenes"][:]))
                self.sceneids = hf["sceneids"][:]
                self.scope = hf["scope"][:]
                self.video_flags = hf["video_flags"][:]
                self.groups = hf["groups"][:]
                self.id_ranges = hf["id_ranges"][:]
                self.images = hf["images"][:]
        else:
            scene_dirs = sorted(
                [
                    d
                    for d in os.listdir(self.ROOT)
                    if os.path.isdir(os.path.join(self.ROOT, d))
                ]
            )
            scenes = []
            sceneids = []
            groups = []
            scope = []
            images = []
            id_ranges = []
            is_video = []
            start = 0
            j = 0
            offset = 0
            for scene in tqdm(scene_dirs):
                scenes.append(scene)
                # video sequences
                subscenes = sorted(
                    [
                        d
                        for d in os.listdir(os.path.join(self.ROOT, scene))
                        if d.startswith("dense")
                    ]
                )
                id_range_subscenes = []
                for subscene in subscenes:
                    rgb_paths = sorted(
                        [
                            d
                            for d in os.listdir(
                                os.path.join(self.ROOT, scene, subscene, "rgb")
                            )
                            if d.endswith(".jpg")
                        ]
                    )
                    assert (
                        len(rgb_paths) > 0
                    ), f"{os.path.join(self.ROOT, scene, subscene)} is empty."
                    num_imgs = len(rgb_paths)
                    images.extend(
                        [self.path2imgid(subscene, rgb_path) for rgb_path in rgb_paths]
                    )
                    id_range_subscenes.append((offset, offset + num_imgs))
                    offset += num_imgs
                # image collections
                with open(os.path.join(self.ROOT, scene, "metadata.pkl"), "rb") as f:
                    metadata = pickle.load(f)
                ref_imgs = list(metadata.keys())
                img_groups = []
                for ref_img in ref_imgs:
                    other_imgs = metadata[ref_img]
                    if len(other_imgs) + 1 < self.num_views:
                        continue
                    group = [(*other_img[0], other_img[1]) for other_img in other_imgs]
                    group.insert(0, (*ref_img, 1))
                    img_groups.append(np.array(group))
                    id_ranges.append(id_range_subscenes[ref_img[0]])
                    scope.append(start)
                    start = start + len(group)
                num_groups = len(img_groups)
                sceneids.extend([j] * num_groups)
                groups.extend(img_groups)
                is_video.extend([False] * num_groups)
                j += 1

            self.scenes = np.array(scenes)
            self.sceneids = np.array(sceneids)
            self.scope = np.array(scope)
            self.video_flags = np.array(is_video)
            self.groups = np.concatenate(groups, 0)
            self.id_ranges = np.array(id_ranges)
            self.images = np.array(images)

            data = dict(
                scenes=self.scenes,
                sceneids=self.sceneids,
                scope=self.scope,
                video_flags=self.video_flags,
                groups=self.groups,
                id_ranges=self.id_ranges,
                images=self.images,
            )
            with h5py.File(cache_file, "w") as h5f:
                h5f.create_dataset(
                    "scenes",
                    data=data["scenes"].astype(object),
                    dtype=h5py.string_dtype(encoding="utf-8"),
                    compression="lzf",
                    chunks=True,
                )
                h5f.create_dataset(
                    "sceneids", data=data["sceneids"], compression="lzf", chunks=True
                )
                h5f.create_dataset(
                    "scope", data=data["scope"], compression="lzf", chunks=True
                )
                h5f.create_dataset(
                    "video_flags",
                    data=data["video_flags"],
                    compression="lzf",
                    chunks=True,
                )
                h5f.create_dataset(
                    "groups", data=data["groups"], compression="lzf", chunks=True
                )
                h5f.create_dataset(
                    "id_ranges", data=data["id_ranges"], compression="lzf", chunks=True
                )
                h5f.create_dataset(
                    "images", data=data["images"], compression="lzf", chunks=True
                )
    def __len__(self):
        return len(self.scope)

    def get_image_num(self):
        return len(self.images)

    def get_stats(self):
        return f"{len(self)} groups of views"
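    # View sampling for one group (as implemented below): with probability 0.6
    # the views are taken as a (possibly block-shuffled) video window from the
    # reference image's dense subscene via the base-class helper
    # get_seq_from_start_id; otherwise num_views frames are drawn from the
    # group's image collection with probability proportional to their overlap
    # score, allowing repeats when the group is too small or allow_repeat is set.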
    def _get_views(self, idx, resolution, rng, num_views):
        scene = self.scenes[self.sceneids[idx]]
        if rng.random() < 0.6:
            ids = np.arange(self.id_ranges[idx][0], self.id_ranges[idx][1])
            cut_off = num_views if not self.allow_repeat else max(num_views // 3, 3)
            start_ids = ids[: len(ids) - cut_off + 1]
            start_id = rng.choice(start_ids)
            pos, ordered_video = self.get_seq_from_start_id(
                num_views,
                start_id,
                ids.tolist(),
                rng,
                max_interval=self.max_interval,
                video_prob=0.8,
                fix_interval_prob=0.5,
                block_shuffle=16,
            )
            ids = np.array(ids)[pos]
            image_idxs = self.images[ids]
        else:
            ordered_video = False
            seq_start_index = self.scope[idx]
            seq_end_index = self.scope[idx + 1] if idx < len(self.scope) - 1 else None
            image_idxs = (
                self.groups[seq_start_index:seq_end_index]
                if seq_end_index is not None
                else self.groups[seq_start_index:]
            )
            image_idxs, overlap_scores = image_idxs[:, :2], image_idxs[:, 2]
            replace = (
                self.allow_repeat
                or len(overlap_scores[overlap_scores > 0]) < num_views
            )
            image_idxs = rng.choice(
                image_idxs,
                num_views,
                replace=replace,
                p=overlap_scores / np.sum(overlap_scores),
            )
        image_idxs = image_idxs.astype(np.int64)

        views = []
        for v, view_idx in enumerate(image_idxs):
            img_path = self.imgid2path(view_idx, scene)
            depth_path = img_path.replace("rgb", "depth").replace(".jpg", ".npy")
            cam_path = img_path.replace("rgb", "cam").replace(".jpg", ".npz")
            sky_mask_path = img_path.replace("rgb", "sky_mask")

            image = imread_cv2(img_path)
            depthmap = np.load(depth_path)
            camera_params = np.load(cam_path)
            sky_mask = cv2.imread(sky_mask_path, cv2.IMREAD_UNCHANGED) >= 127

            intrinsics = camera_params["intrinsic"].astype(np.float32)
            camera_pose = camera_params["pose"].astype(np.float32)

            depthmap[sky_mask] = -1.0
            depthmap[depthmap > 400.0] = 0.0
            depthmap = np.nan_to_num(depthmap, nan=0, posinf=0, neginf=0)
            threshold = (
                np.percentile(depthmap[depthmap > 0], 98)
                if depthmap[depthmap > 0].size > 0
                else 0
            )
            depthmap[depthmap > threshold] = 0.0

            image, depthmap, intrinsics = self._crop_resize_if_necessary(
                image, depthmap, intrinsics, resolution, rng, info=img_path
            )

            # generate img mask and raymap mask
            img_mask, ray_mask = self.get_img_and_ray_masks(
                self.is_metric, v, rng, p=[0.75, 0.2, 0.05]
            )

            views.append(
                dict(
                    img=image,
                    depthmap=depthmap,
                    camera_pose=camera_pose,  # cam2world
                    camera_intrinsics=intrinsics,
                    dataset="MapFree",
                    label=img_path,
                    is_metric=self.is_metric,
                    instance=img_path,
                    is_video=ordered_video,
                    quantile=np.array(0.96, dtype=np.float32),
                    img_mask=img_mask,
                    ray_mask=ray_mask,
                    camera_only=False,
                    depth_only=False,
                    single_view=False,
                    reset=False,
                )
            )
        assert len(views) == num_views
        return views