liguang0115's picture
Add initial project structure with core files, configurations, and sample images
2df809d
import os.path as osp
import json
import itertools
from collections import deque
import sys
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
import cv2
import numpy as np
import time
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2
class Co3d_Multi(BaseMultiViewDataset):
    """Multi-view dataset over preprocessed CO3D sequences (label ``"Co3d_v2"``).

    The sequence index is read from ``<ROOT>/selected_seqs_<split>.json`` and
    flattened into ``{(obj, instance): [frame ids]}``.  Every sample is anchored
    on one reference frame; the remaining views are drawn from the same
    sequence.

    Frames whose depth map has no positive pixel after cropping/resizing are
    remembered per resolution in ``self.invalidate`` and sequences that run out
    of usable frames are remembered in ``self.invalid_scenes``, so they are not
    retried on later calls.
    """

    def __init__(self, mask_bg="rand", *args, ROOT, **kwargs):
        """Load the sequence index and build the list of reference frames.

        Args:
            mask_bg: ``True`` to always zero the depth outside the object mask,
                ``False`` to never do so, ``"rand"`` to do it with probability
                0.1 for each sampled sequence.
            *args: Forwarded to the base-class constructor.
            ROOT: Dataset root directory (keyword-only).  It must contain
                ``selected_seqs_<split>.json`` and, per ``<obj>/<instance>``,
                the ``images``, ``depths`` and ``masks`` sub-folders.
            **kwargs: Forwarded to the base-class constructor.
        """
        self.ROOT = ROOT
        # NOTE(review): `split`, `num_views` and `allow_repeat` are read below
        # and are presumably set by the base-class constructor -- confirm.
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, "rand")
        self.mask_bg = mask_bg
        self.is_metric = False
        self.dataset_label = "Co3d_v2"

        # load all scenes
        with open(osp.join(self.ROOT, f"selected_seqs_{self.split}.json"), "r") as f:
            self.scenes = json.load(f)
        # Drop objects without any sequence, then flatten the two-level mapping
        # {obj: {instance: frames}} into {(obj, instance): frames}.
        self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
        self.scenes = {
            (k, k2): v2 for k, v in self.scenes.items() for k2, v2 in v.items()
        }
        self.scene_list = list(self.scenes.keys())

        # Minimum number of usable frames a sequence must provide.
        cut_off = (
            self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
        )
        self.cut_off = cut_off

        # Only the first `len(values) - cut_off + 1` frames of a sequence can be
        # reference frames.  The bound is clamped at 0: without the clamp a
        # sequence shorter than `cut_off` yields a *negative* slice end, which
        # Python interprets as "all but the last k" and would wrongly keep
        # frames of a sequence that is too short to be sampled.
        self.all_ref_imgs = [
            (key, value)
            for key, values in self.scenes.items()
            for value in values[: max(len(values) - cut_off + 1, 0)]
        ]

        # Per sequence: {resolution: [bool per frame position]}, True = unusable.
        self.invalidate = {scene: {} for scene in self.scene_list}
        # Per sequence: True once it has too few usable frames left.
        self.invalid_scenes = {scene: False for scene in self.scene_list}

    def __len__(self):
        """Return the number of (sequence, reference frame) pairs."""
        return len(self.all_ref_imgs)

    # NOTE: the path builders below use the `d` presentation type rather than
    # `n`: `n` is locale-aware and may insert digit-group separators into the
    # file name once a locale has been set.

    def _get_metadatapath(self, obj, instance, view_idx):
        """Return the path of the ``.npz`` camera metadata of a frame."""
        return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06d}.npz")

    def _get_impath(self, obj, instance, view_idx):
        """Return the path of the RGB image of a frame."""
        return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06d}.jpg")

    def _get_depthpath(self, obj, instance, view_idx):
        """Return the path of the depth PNG of a frame."""
        return osp.join(
            self.ROOT, obj, instance, "depths", f"frame{view_idx:06d}.jpg.geometric.png"
        )

    def _get_maskpath(self, obj, instance, view_idx):
        """Return the path of the object-mask PNG of a frame."""
        return osp.join(self.ROOT, obj, instance, "masks", f"frame{view_idx:06d}.png")

    def _read_depthmap(self, depthpath, input_metadata):
        """Read a depth PNG and rescale it with the frame's ``maximum_depth``.

        The raw values are divided by 65535 and multiplied by
        ``input_metadata["maximum_depth"]`` after it has been passed through
        ``np.nan_to_num``.

        Args:
            depthpath: Path of the depth image.
            input_metadata: Mapping that provides ``"maximum_depth"``.

        Returns:
            The rescaled depth map (float32 scaled by the metadata value).
        """
        depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
        max_depth = np.nan_to_num(input_metadata["maximum_depth"])
        return (depthmap.astype(np.float32) / 65535) * max_depth

    def _get_views(self, idx, resolution, rng, num_views):
        """Assemble ``num_views`` views around reference entry ``idx``.

        Args:
            idx: Index into ``self.all_ref_imgs``.
            resolution: Target resolution handed to
                ``_crop_resize_if_necessary``; also used as a dictionary key of
                the per-sequence invalidation table, so it must be hashable.
            rng: Random generator providing ``integers`` and ``choice``
                (presumably a ``numpy.random.Generator`` -- confirm with caller).
            num_views: Number of views to return.

        Returns:
            A list of ``num_views`` view dictionaries.

        NOTE(review): a sequence is accepted only if its views do not all share
        the same image file name, so with ``num_views == 1`` the acceptance
        condition can never hold and this loop would not terminate -- confirm
        that single-view sampling is not expected here.
        """
        invalid_seq = True
        scene_info, ref_img_idx = self.all_ref_imgs[idx]

        while invalid_seq:
            # Replace a sequence already known to be unusable by a random entry.
            while self.invalid_scenes[scene_info]:
                idx = rng.integers(low=0, high=len(self.all_ref_imgs))
                scene_info, ref_img_idx = self.all_ref_imgs[idx]

            obj, instance = scene_info
            image_pool = self.scenes[obj, instance]
            if len(image_pool) < self.cut_off:
                print("Invalid scene!")
                self.invalid_scenes[scene_info] = True
                continue

            imgs_idxs, ordered_video = self.get_seq_from_start_id(
                num_views, ref_img_idx, image_pool, rng
            )

            # Lazily create the per-resolution table flagging invalid images.
            invalid_by_res = self.invalidate[obj, instance]
            if resolution not in invalid_by_res:
                invalid_by_res[resolution] = [False] * len(image_pool)
            invalid_flags = invalid_by_res[resolution]

            # decide now if we mask the bg (the rng is only consumed for "rand")
            if self.mask_bg == "rand":
                mask_bg = bool(rng.choice(2, p=[0.9, 0.1]))
            else:
                mask_bg = bool(self.mask_bg)

            views = []
            imgs_idxs = deque(imgs_idxs)
            while imgs_idxs:  # some images (few) have zero depth
                if len(image_pool) - sum(invalid_flags) < self.cut_off:
                    # Too few usable frames remain: give up on this sequence.
                    print("Invalid scene!")
                    invalid_seq = True
                    self.invalid_scenes[scene_info] = True
                    break

                im_idx = imgs_idxs.pop()
                if invalid_flags[im_idx]:
                    # search for a valid image, walking away from the bad one
                    # in a random direction (with wrap-around)
                    ordered_video = False
                    random_direction = 2 * rng.choice(2) - 1
                    for offset in range(1, len(image_pool)):
                        tentative_im_idx = (
                            im_idx + (random_direction * offset)
                        ) % len(image_pool)
                        if not invalid_flags[tentative_im_idx]:
                            im_idx = tentative_im_idx
                            break

                view_idx = image_pool[im_idx]
                impath = self._get_impath(obj, instance, view_idx)
                depthpath = self._get_depthpath(obj, instance, view_idx)
                metadata_path = self._get_metadatapath(obj, instance, view_idx)

                # The .npz archive keeps a file handle open; the `with` block
                # releases it as soon as the needed arrays have been read.
                with np.load(metadata_path) as input_metadata:
                    # load camera params
                    camera_pose = input_metadata["camera_pose"].astype(np.float32)
                    intrinsics = input_metadata["camera_intrinsics"].astype(
                        np.float32
                    )
                    # load image and depth
                    rgb_image = imread_cv2(impath)
                    depthmap = self._read_depthmap(depthpath, input_metadata)

                if mask_bg:
                    # load object mask
                    maskpath = self._get_maskpath(obj, instance, view_idx)
                    maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(
                        np.float32
                    )
                    maskmap = (maskmap / 255.0) > 0.1
                    # update the depthmap with mask
                    depthmap *= maskmap

                rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                    rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath
                )

                num_valid = (depthmap > 0.0).sum()
                if num_valid == 0:
                    # problem, invalidate image and retry: the index is pushed
                    # back so that the next iteration picks a replacement
                    invalid_flags[im_idx] = True
                    imgs_idxs.append(im_idx)
                    continue

                # generate img mask and raymap mask
                img_mask, ray_mask = self.get_img_and_ray_masks(
                    self.is_metric, len(views), rng
                )
                views.append(
                    dict(
                        img=rgb_image,
                        depthmap=depthmap,
                        camera_pose=camera_pose,
                        camera_intrinsics=intrinsics,
                        dataset=self.dataset_label,
                        label=osp.join(obj, instance),
                        instance=osp.split(impath)[1],
                        is_metric=self.is_metric,
                        is_video=ordered_video,
                        quantile=np.array(0.9, dtype=np.float32),
                        img_mask=img_mask,
                        ray_mask=ray_mask,
                        camera_only=False,
                        depth_only=False,
                        single_view=False,
                        reset=False,
                    )
                )

            # Accept only a complete set of views that is not made of one single
            # image repeated `num_views` times.
            if len(views) == num_views and not all(
                view["instance"] == views[0]["instance"] for view in views
            ):
                invalid_seq = False
        return views