# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Dataloader for Spring
# --------------------------------------------------------
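# Loads pairs of nearby frames from preprocessed Spring scenes. For every frame the
# loader reads the RGB image, the ground-truth depth (PFM), a validity mask, camera
# intrinsics/pose from a metadata .npz, and a precomputed monocular depth prior
# (file naming follows the preprocessing assumed by the glob patterns below).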

import os.path as osp
from glob import glob
import itertools
import numpy as np
import re
import cv2
import os

from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from dust3r.utils.image import imread_cv2


def readPFM(file):
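    """Read a PFM image into a numpy array.

    Returns an (H, W) float array for grayscale ('Pf') files and (H, W, 3) for
    color ('PF') files. PFM stores rows bottom-to-top, so the result is flipped
    vertically into the usual top-to-bottom order.
    """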
    with open(file, 'rb') as f:
        # header: 'PF' = color, 'Pf' = grayscale
        header = f.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        # image dimensions: "<width> <height>"
        dim_match = re.match(rb'^(\d+)\s(\d+)\s$', f.readline())
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')

        # the sign of the scale encodes the endianness of the raw float data
        scale = float(f.readline().rstrip())
        if scale < 0:  # little-endian
            endian = '<'
            scale = -scale
        else:
            endian = '>'  # big-endian

        data = np.fromfile(f, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data


class SpringDatasets(BaseStereoViewDataset):
    def __init__(self, *args, split, ROOT, **kwargs):
        self.ROOT = ROOT  # e.g. ROOT = "/media/tyhuang/T9/videodepth_data/spring_proc/train"
        super().__init__(*args, **kwargs)
        self.dataset_label = 'Spring'

        test_scenes = []

        scene_list = []
        for scene in os.listdir(ROOT):
            if scene not in test_scenes and split == 'train':
                scene_list.append(osp.join(ROOT, scene))
            if scene in test_scenes and split == 'test':
                scene_list.append(osp.join(ROOT, scene))

        # enumerate all frame pairs whose temporal distance is at most 10 frames
        self.pair_dict = {}
        pair_num = 0
        for scene in scene_list:
            imgs = sorted(glob(osp.join(scene, '*_rgb.jpg')))
            len_imgs = len(imgs)
            # combinations = [(i, j) for i, j in itertools.combinations(range(len_imgs), 2)
            #                 if abs(i - j) <= 20 or (abs(i - j) <= 60 and abs(i - j) % 3 == 0)]
            combinations = [(i, j) for i, j in itertools.combinations(range(len_imgs), 2)
                            if abs(i - j) <= 10]
            for (i, j) in combinations:
                self.pair_dict[pair_num] = [imgs[i], imgs[j]]
                pair_num += 1

    def __len__(self):
        return len(self.pair_dict)

    def _get_views(self, idx, resolution, rng):

        views = []
        for img_path in self.pair_dict[idx]:

            rgb_image = imread_cv2(img_path)

            depthmap_path = img_path.replace('_rgb.jpg', '_depth.pfm')
            mask_path = img_path.replace('_rgb.jpg', '_mask.png')
            metadata_path = img_path.replace('_rgb.jpg', '_metadata.npz')

            # monocular depth prior, lifted to a point map with the predicted focal length
            pred_depth = np.load(img_path.replace('.jpg', '_pred_depth_' + self.depth_prior_name + '.npz'))
            focal_length_px = pred_depth['focallength_px']
            pred_depth = pred_depth['depth']
            pred_depth = self.pixel_to_pointcloud(pred_depth, focal_length_px)

            # ground-truth depth, with invalid pixels zeroed out by the mask
            depthmap = readPFM(depthmap_path)
            # scale = depthmap.min() + depthmap.min()
            maskmap = imread_cv2(mask_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
            maskmap = (maskmap / 255.0) > 0.1
            # maskmap = maskmap * (depthmap < 100)
            depthmap *= maskmap

            metadata = np.load(metadata_path)
            intrinsics = np.float32(metadata['camera_intrinsics'])
            camera_pose = np.float32(metadata['camera_pose'])
            # max_depth = np.float32(metadata['maximum_depth'])

            # depthmap = (depthmap.astype(np.float32) / 200.0)
            # pred_depth = pred_depth / 200.0
            # camera_pose[:3, 3] /= 200.0

            rgb_image, depthmap, pred_depth, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, pred_depth, intrinsics, resolution, rng=rng, info=img_path)

            num_valid = (depthmap > 0.0).sum()
            # assert num_valid > 0
            # if num_valid == 0:
            #     depthmap += 0.001

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset=self.dataset_label,
                label=img_path,
                instance=img_path,
                pred_depth=pred_depth,
            ))

        return views


if __name__ == "__main__":
    from dust3r.datasets.base.base_stereo_view_dataset import view_name
    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb

    dataset = SpringDatasets(split='train',
                             ROOT="/media/8TB/tyhuang/video_depth/spring_proc/train",
                             resolution=512, aug_crop=16)

    for idx in np.random.permutation(len(dataset)):
        views = dataset[idx]
        assert len(views) == 2
        print(view_name(views[0]), view_name(views[1]))
        viz = SceneViz()
        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
        cam_size = max(auto_cam_size(poses), 0.001)
        for view_idx in [0, 1]:
            pts3d = views[view_idx]['pts3d']
            valid_mask = views[view_idx]['valid_mask']
            colors = rgb(views[view_idx]['img'])
            viz.add_pointcloud(pts3d, colors, valid_mask)
            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
                           focal=views[view_idx]['camera_intrinsics'][0, 0],
                           color=(idx * 255, (1 - idx) * 255, 0),
                           image=colors,
                           cam_size=cam_size)
        viz.show()