import os.path as osp
import json
import itertools
from collections import deque
import sys
sys.path.append(osp.join(osp.dirname(__file__), "..", ".."))
import cv2
import numpy as np
import time
from dust3r.datasets.base.base_multiview_dataset import BaseMultiViewDataset
from dust3r.utils.image import imread_cv2


class Co3d_Multi(BaseMultiViewDataset):
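    """Multi-view dataset for CO3Dv2.

    Reads the per-split index `selected_seqs_{split}.json` under ROOT, flattens it
    into scenes keyed by (category, sequence), and samples sets of `num_views`
    frames per scene with RGB, depth, intrinsics, and camera pose. The background
    can optionally be masked out of the depth maps via the object masks.
    """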
def __init__(self, mask_bg="rand", *args, ROOT, **kwargs):
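        """
        mask_bg: True to always zero out depth outside the object mask,
            False to never do so, or "rand" to apply the mask to a random
            ~10% of the sampled sequences.
        ROOT: root directory of the preprocessed CO3Dv2 data.
        """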
self.ROOT = ROOT
super().__init__(*args, **kwargs)
assert mask_bg in (True, False, "rand")
self.mask_bg = mask_bg
self.is_metric = False
self.dataset_label = "Co3d_v2"
# load all scenes
with open(osp.join(self.ROOT, f"selected_seqs_{self.split}.json"), "r") as f:
self.scenes = json.load(f)
self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
self.scenes = {
(k, k2): v2 for k, v in self.scenes.items() for k2, v2 in v.items()
}
self.scene_list = list(self.scenes.keys())
cut_off = (
self.num_views if not self.allow_repeat else max(self.num_views // 3, 3)
)
self.cut_off = cut_off
self.all_ref_imgs = [
(key, value)
for key, values in self.scenes.items()
for value in values[: len(values) - cut_off + 1]
]
self.invalidate = {scene: {} for scene in self.scene_list}
self.invalid_scenes = {scene: False for scene in self.scene_list}

    def __len__(self):
return len(self.all_ref_imgs)

    def _get_metadatapath(self, obj, instance, view_idx):
return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06n}.npz")

    def _get_impath(self, obj, instance, view_idx):
return osp.join(self.ROOT, obj, instance, "images", f"frame{view_idx:06n}.jpg")

    def _get_depthpath(self, obj, instance, view_idx):
return osp.join(
self.ROOT, obj, instance, "depths", f"frame{view_idx:06n}.jpg.geometric.png"
)

    def _get_maskpath(self, obj, instance, view_idx):
return osp.join(self.ROOT, obj, instance, "masks", f"frame{view_idx:06n}.png")

    def _read_depthmap(self, depthpath, input_metadata):
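        """Decode a uint16 depth PNG stored as fractions of the frame's
        `maximum_depth`: divide by 65535, then rescale by the metadata value."""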
depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(
input_metadata["maximum_depth"]
)
return depthmap

    def _get_views(self, idx, resolution, rng, num_views):
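        """Sample `num_views` frames of one scene at the requested resolution.

        Starts from the reference frame behind `idx`, skips frames whose depth
        map is empty after masking/cropping (marking them invalid for this
        resolution), and falls back to another random scene when too few valid
        frames remain.
        """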
invalid_seq = True
scene_info, ref_img_idx = self.all_ref_imgs[idx]
while invalid_seq:
while self.invalid_scenes[scene_info]:
idx = rng.integers(low=0, high=len(self.all_ref_imgs))
scene_info, ref_img_idx = self.all_ref_imgs[idx]
obj, instance = scene_info
image_pool = self.scenes[obj, instance]
if len(image_pool) < self.cut_off:
print("Invalid scene!")
self.invalid_scenes[scene_info] = True
continue
imgs_idxs, ordered_video = self.get_seq_from_start_id(
num_views, ref_img_idx, image_pool, rng
)
if resolution not in self.invalidate[obj, instance]: # flag invalid images
self.invalidate[obj, instance][resolution] = [
False for _ in range(len(image_pool))
]
# decide now if we mask the bg
mask_bg = (self.mask_bg == True) or (
self.mask_bg == "rand" and rng.choice(2, p=[0.9, 0.1])
)
views = []
imgs_idxs = deque(imgs_idxs)
while len(imgs_idxs) > 0: # some images (few) have zero depth
if (
len(image_pool) - sum(self.invalidate[obj, instance][resolution])
< self.cut_off
):
print("Invalid scene!")
invalid_seq = True
self.invalid_scenes[scene_info] = True
break
im_idx = imgs_idxs.pop()
if self.invalidate[obj, instance][resolution][im_idx]:
# search for a valid image
ordered_video = False
random_direction = 2 * rng.choice(2) - 1
for offset in range(1, len(image_pool)):
tentative_im_idx = (im_idx + (random_direction * offset)) % len(
image_pool
)
if not self.invalidate[obj, instance][resolution][
tentative_im_idx
]:
im_idx = tentative_im_idx
break
view_idx = image_pool[im_idx]
impath = self._get_impath(obj, instance, view_idx)
depthpath = self._get_depthpath(obj, instance, view_idx)
# load camera params
metadata_path = self._get_metadatapath(obj, instance, view_idx)
input_metadata = np.load(metadata_path)
camera_pose = input_metadata["camera_pose"].astype(np.float32)
intrinsics = input_metadata["camera_intrinsics"].astype(np.float32)
# load image and depth
rgb_image = imread_cv2(impath)
depthmap = self._read_depthmap(depthpath, input_metadata)
if mask_bg:
# load object mask
maskpath = self._get_maskpath(obj, instance, view_idx)
maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(
np.float32
)
maskmap = (maskmap / 255.0) > 0.1
# update the depthmap with mask
depthmap *= maskmap
rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath
)
num_valid = (depthmap > 0.0).sum()
if num_valid == 0:
# problem, invalidate image and retry
self.invalidate[obj, instance][resolution][im_idx] = True
imgs_idxs.append(im_idx)
continue
# generate img mask and raymap mask
img_mask, ray_mask = self.get_img_and_ray_masks(
self.is_metric, len(views), rng
)
views.append(
dict(
img=rgb_image,
depthmap=depthmap,
camera_pose=camera_pose,
camera_intrinsics=intrinsics,
dataset=self.dataset_label,
label=osp.join(obj, instance),
instance=osp.split(impath)[1],
is_metric=self.is_metric,
is_video=ordered_video,
quantile=np.array(0.9, dtype=np.float32),
img_mask=img_mask,
ray_mask=ray_mask,
camera_only=False,
depth_only=False,
single_view=False,
reset=False,
)
)
if len(views) == num_views and not all(
[view["instance"] == views[0]["instance"] for view in views]
):
invalid_seq = False
return views
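

if __name__ == "__main__":
    # Illustrative only: a minimal sketch of how this loader might be
    # instantiated. `ROOT` and `mask_bg` are defined above; the remaining
    # keyword arguments are forwarded to BaseMultiViewDataset, so their exact
    # names (split, num_views, resolution, ...) and the data path are
    # assumptions here, not guaranteed by this file.
    dataset = Co3d_Multi(
        ROOT="data/co3d_processed",  # placeholder path to preprocessed CO3Dv2
        mask_bg="rand",
        split="train",
        num_views=4,
        resolution=(512, 384),
    )
    print(f"{len(dataset)} reference frames across {len(dataset.scene_list)} scenes")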