Spaces:
Runtime error
Runtime error
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# base class for implementing datasets
# --------------------------------------------------------
import PIL | |
import numpy as np | |
import torch | |
from eval.mv_recon.dataset_utils.transforms import ImgNorm | |
from dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates | |
import eval.mv_recon.dataset_utils.cropping as cropping | |
class BaseStereoViewDataset:
    """Base class for multi-view datasets; defines all the basic options.

    Subclasses must override ``_get_views`` and provide ``self.scenes``
    (``__len__`` returns ``len(self.scenes)``).

    Usage:
        class MyDataset (BaseStereoViewDataset):
            def _get_views(self, idx, resolution, rng):
                # overload here
                views = []
                views.append(dict(img=, ...))
                return views
    """

    def __init__(
        self,
        *,  # only keyword arguments
        split=None,
        resolution=None,  # square_size or (width, height) or list of [(width,height), ...]
        transform=ImgNorm,
        aug_crop=False,
        seed=None,
    ):
        """Store the dataset options.

        Params:
            - split: stored as-is on ``self.split``.
            - resolution: int, (width, height) or a list of those; validated
              by ``_set_resolutions``.
            - transform: callable applied to every view image, or a string
              that is evaluated to obtain that callable.
            - aug_crop: stored as-is on ``self.aug_crop``.
            - seed: when not None, the RNG is re-seeded with ``seed + idx``
              on every ``__getitem__`` call.
        """
        self.num_views = 2
        self.split = split
        self._set_resolutions(resolution)

        if isinstance(transform, str):
            # NOTE(review): eval() executes arbitrary code -- only pass
            # trusted strings here (e.g. names coming from config files).
            transform = eval(transform)
        # Assign *after* resolving a string so self.transform is callable.
        self.transform = transform

        self.aug_crop = aug_crop
        self.seed = seed

    def __len__(self):
        """Return the number of scenes (``self.scenes`` is set by subclasses)."""
        return len(self.scenes)

    def get_stats(self):
        """Return a short human-readable size summary."""
        return f"{len(self)} pairs"

    def __repr__(self):
        """Return a compact one-line description of the dataset."""
        resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
        description = (
            f"{type(self).__name__}({self.get_stats()},"
            f"{self.split=},"
            f"{self.seed=},"
            f"resolutions={resolutions_str},"
            f"{self.transform=})"
        )
        # NOTE(review): the last replace strips every space, including the
        # ones inside get_stats()/transform reprs -- confirm this is intended.
        return description.replace("self.", "").replace("\n", "").replace(" ", "")

    def _get_views(self, idx, resolution, rng):
        """Return the list of view dicts for sample ``idx`` (to be overridden)."""
        raise NotImplementedError()

    def __getitem__(self, idx):
        """Load, validate and post-process the views of one sample.

        Params:
            - idx: an int, or a tuple ``(idx, ar_idx)`` where ``ar_idx``
              selects one of the configured resolutions.

        Returns:
            The list of view dicts produced by ``_get_views``, each one
            augmented in place with ``idx``, ``true_shape``, the transformed
            ``img``, ``camera_pose`` (NaN-filled when absent), ``pts3d``,
            ``valid_mask``, ``img_mask``, ``ray_mask``, ``ray_map``,
            ``update``, ``reset`` and ``rng``.
        """
        if isinstance(idx, tuple):
            # the idx is specifying the aspect-ratio
            idx, ar_idx = idx
        else:
            assert len(self._resolutions) == 1
            ar_idx = 0

        # set-up the rng
        if self.seed is not None:  # reseed for each __getitem__ (seed=0 is valid)
            self._rng = np.random.default_rng(seed=self.seed + idx)
        elif not hasattr(self, "_rng"):
            seed = torch.initial_seed()  # this is different for each dataloader process
            self._rng = np.random.default_rng(seed=seed)

        # over-loaded code
        # DO NOT CHANGE THIS (compatible with BatchedRandomSampler)
        resolution = self._resolutions[ar_idx]
        views = self._get_views(idx, resolution, self._rng)

        # check data-types
        for v, view in enumerate(views):
            assert (
                "pts3d" not in view
            ), f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
            view["idx"] = v

            # encode the image
            width, height = view["img"].size
            view["true_shape"] = np.int32((height, width))
            view["img"] = self.transform(view["img"])

            assert "camera_intrinsics" in view
            if "camera_pose" not in view:
                view["camera_pose"] = np.full((4, 4), np.nan, dtype=np.float32)
            else:
                assert np.isfinite(
                    view["camera_pose"]
                ).all(), f"NaN in camera pose for view {view_name(view)}"
            assert "pts3d" not in view
            assert "valid_mask" not in view
            assert np.isfinite(
                view["depthmap"]
            ).all(), f"NaN in depthmap for view {view_name(view)}"

            # the whole view dict is forwarded as keyword arguments
            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
            view["pts3d"] = pts3d
            view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)

            # check all datatypes
            for key, val in view.items():
                res, err_msg = is_good_type(key, val)
                assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"

            # extra per-view flags, added after the datatype check
            view["img_mask"] = True
            view["ray_mask"] = False
            view["ray_map"] = torch.full(
                (6, view["img"].shape[-2], view["img"].shape[-1]), torch.nan
            )
            view["update"] = True
            view["reset"] = False

        # last thing done!
        for view in views:
            # transpose to make sure all views are the same size
            transpose_to_landscape(view)
            # this allows checking whether the RNG is in the same state each time
            view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
        return views

    def _set_resolutions(self, resolutions):
        """Set the resolution(s) of the dataset.

        Params:
            - resolutions: int or tuple or list of tuples

        Each entry is stored in ``self._resolutions`` as ``(width, height)``;
        an int is expanded to a square. Width must be >= height.
        """
        assert resolutions is not None, "undefined resolution"

        if not isinstance(resolutions, list):
            resolutions = [resolutions]

        self._resolutions = []
        for resolution in resolutions:
            if isinstance(resolution, int):
                width = height = resolution
            else:
                width, height = resolution
            assert isinstance(
                width, int
            ), f"Bad type for {width=} {type(width)=}, should be int"
            assert isinstance(
                height, int
            ), f"Bad type for {height=} {type(height)=}, should be int"
            assert width >= height
            self._resolutions.append((width, height))

    def _crop_resize_if_necessary(
        self, image, depthmap, intrinsics, resolution, rng=None, info=None
    ):
        """Crop and rescale an image/depthmap/intrinsics triplet to ``resolution``.

        Steps:
            1. crop a window centered on the principal point,
            2. flip ``resolution`` to portrait when the cropped image is
               portrait (or randomly, using ``rng``, when it is near-square
               and ``resolution`` is not square),
            3. rescale with ``cropping.rescale_image_depthmap`` (high-quality
               Lanczos down-scaling according to the original comments),
            4. crop to the exact target resolution.

        Params:
            - image: PIL image or array (converted with ``PIL.Image.fromarray``).
            - depthmap, intrinsics: forwarded to the ``cropping`` helpers.
            - resolution: (width, height) with width >= height.
            - rng: random generator; required when the cropped image is
              near-square and ``resolution`` is not square.
            - info: only used in assertion messages.

        Returns:
            (image, depthmap, intrinsics) after cropping and rescaling.
        """
        if not isinstance(image, PIL.Image.Image):
            image = PIL.Image.fromarray(image)

        # cropping centered on the principal point
        W, H = image.size
        cx, cy = intrinsics[:2, 2].round().astype(int)
        # calculate min distance to margin
        min_margin_x = min(cx, W - cx)
        min_margin_y = min(cy, H - cy)
        assert min_margin_x > W / 5, f"Bad principal point in view={info}"
        assert min_margin_y > H / 5, f"Bad principal point in view={info}"

        # Center crop: the new window is a rectangle of size
        # (2*min_margin_x, 2*min_margin_y) centered on (cx, cy)
        l, t = cx - min_margin_x, cy - min_margin_y
        r, b = cx + min_margin_x, cy + min_margin_y
        crop_bbox = (l, t, r, b)
        image, depthmap, intrinsics = cropping.crop_image_depthmap(
            image, depthmap, intrinsics, crop_bbox
        )

        # transpose the resolution if necessary
        W, H = image.size  # new size
        assert resolution[0] >= resolution[1]
        if H > 1.1 * W:
            # image is portrait mode
            resolution = resolution[::-1]
        elif 0.9 < H / W < 1.1 and resolution[0] != resolution[1]:
            # image is square, so we chose (portrait, landscape) randomly
            if rng.integers(2):
                resolution = resolution[::-1]

        # high-quality Lanczos down-scaling
        target_resolution = np.array(resolution)
        image, depthmap, intrinsics = cropping.rescale_image_depthmap(
            image, depthmap, intrinsics, target_resolution
        )

        # actual cropping (if necessary)
        intrinsics2 = cropping.camera_matrix_of_crop(
            intrinsics, image.size, resolution, offset_factor=0.5
        )
        crop_bbox = cropping.bbox_from_intrinsics_in_out(
            intrinsics, intrinsics2, resolution
        )
        image, depthmap, intrinsics = cropping.crop_image_depthmap(
            image, depthmap, intrinsics, crop_bbox
        )

        return image, depthmap, intrinsics
def is_good_type(key, v):
    """Check that a view entry has an accepted type.

    Params:
        - key: name of the entry (not used by the check itself).
        - v: the value to check.

    Returns:
        (is_good, err_msg): ``(True, None)`` for str/int/tuple values and for
        values whose ``dtype`` is one of the accepted dtypes; otherwise
        ``(False, <message>)``.
    """
    if isinstance(v, (str, int, tuple)):
        return True, None
    # Values without a dtype (float, None, list, ...) are reported as bad
    # instead of crashing with AttributeError.
    if not hasattr(v, "dtype"):
        return False, f"bad {type(v)=}"
    if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
        return False, f"bad {v.dtype=}"
    return True, None
def view_name(view, batch_index=None):
    """Build the ``"dataset/label/instance"`` identifier of a view.

    When ``batch_index`` is neither None nor ``slice(None)``, each of the
    three entries is indexed with it first.
    """

    def pick(key):
        entry = view[key]
        if batch_index in (None, slice(None)):
            return entry
        return entry[batch_index]

    db, label, instance = pick("dataset"), pick("label"), pick("instance")
    return f"{db}/{label}/{instance}"
def transpose_to_landscape(view):
    """Transpose a portrait view in place so that it becomes landscape.

    ``view["true_shape"]`` is read as (height, width). When width < height,
    the spatial axes of ``img``, ``valid_mask``, ``depthmap`` and ``pts3d``
    are swapped and the first two rows of ``camera_intrinsics`` are
    exchanged. Landscape and square views are left untouched.
    """
    height, width = view["true_shape"]
    if width >= height:
        return

    portrait = (height, width)

    # rectify portrait to landscape
    assert view["img"].shape == (3, *portrait)
    view["img"] = view["img"].swapaxes(1, 2)

    # per-pixel maps: (key, trailing dimensions after the two spatial axes)
    for key, trailing in (("valid_mask", ()), ("depthmap", ()), ("pts3d", (3,))):
        assert view[key].shape == portrait + trailing
        view[key] = view[key].swapaxes(0, 1)

    # transpose x and y pixels
    view["camera_intrinsics"] = view["camera_intrinsics"][[1, 0, 2]]