|
import os |
|
import json |
|
import torch |
|
import torchvision.transforms as transforms |
|
import os.path |
|
import numpy as np |
|
import cv2 |
|
from torch.utils.data import Dataset |
|
import random |
|
from .__base_dataset__ import BaseDataset |
|
|
|
|
|
class ScanNetDataset(BaseDataset):
    """Dataset that assembles RGB / depth / normal samples for ScanNet.

    Annotation handling, image transforms and depth normalisation are
    inherited from ``BaseDataset``; this class adds the depth decoding
    (``metric_scale``) and the surface-normal loading used here.
    """

    # Down-sampling factors applied to the camera model to build the
    # multi-resolution stack returned with every sample.
    _CAM_MODEL_STRIDES = (2, 4, 8, 16, 32)

    def __init__(self, cfg, phase, **kwargs):
        """Initialise the base dataset and store ``cfg.metric_scale``.

        Args:
            cfg: Dataset configuration; must expose ``metric_scale``.
            phase: Phase identifier forwarded to ``BaseDataset``.
            **kwargs: Extra options forwarded to ``BaseDataset``.
        """
        super().__init__(cfg=cfg, phase=phase, **kwargs)
        # Divisor applied to raw depth values in ``process_depth``.
        self.metric_scale = cfg.metric_scale

    def _build_cam_model_stack(self, cam_model):
        """Return ``cam_model`` bilinearly resized at each pyramid stride.

        Args:
            cam_model: Tensor indexed as ``[channel, height, width]``.

        Returns:
            list: One resized tensor per entry of ``_CAM_MODEL_STRIDES``.
        """
        height, width = cam_model.shape[1], cam_model.shape[2]
        batched = cam_model[None, :, :, :]  # interpolate expects a batch axis
        return [
            torch.nn.functional.interpolate(
                batched,
                size=(height // stride, width // stride),
                mode='bilinear',
                align_corners=False,
            ).squeeze()
            for stride in self._CAM_MODEL_STRIDES
        ]

    def get_data_for_test(self, idx: int, test_mode=True):
        """Build one evaluation sample.

        Args:
            idx: Index into ``self.annotations['files']``.
            test_mode: Forwarded to ``load_batch``.

        Returns:
            dict: Transformed input image plus target depth, inverse depth,
            intrinsic matrix, camera-model stack, normals and bookkeeping
            fields (filename, pad, scale, sample id, ...).
        """
        anno = self.annotations['files'][idx]
        meta_data = self.load_meta_data(anno)
        data_path = self.load_data_path(meta_data)
        data_batch = self.load_batch(meta_data, data_path, test_mode)

        curr_rgb = data_batch['curr_rgb']
        curr_depth = data_batch['curr_depth']
        curr_normal = data_batch['curr_normal']
        curr_cam_model = data_batch['curr_cam_model']
        ori_curr_intrinsic = meta_data['cam_in']

        # Only the image, intrinsics and camera model are consumed from the
        # transform output; the target below uses the untransformed depth.
        rgbs, _, intrinsics, cam_models, _, _, transform_paras = self.img_transforms(
            images=[curr_rgb],
            labels=[curr_depth],
            intrinsics=[ori_curr_intrinsic],
            cam_models=[curr_cam_model],
            transform_paras=dict())

        depth_out = self.clip_depth(curr_depth) * self.depth_range[1]
        # ``np.bool`` was removed from NumPy; the builtin ``bool`` is the
        # equivalent dtype. An all-False mask is passed here.
        inv_depth = self.depth2invdepth(
            depth_out, np.zeros_like(depth_out, dtype=bool))
        filename = os.path.basename(meta_data['rgb'])[:-4] + '.jpg'
        curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])

        pad = transform_paras.get('pad', [0, 0, 0, 0])
        scale_ratio = transform_paras.get('label_scale_factor', 1.0)
        cam_models_stacks = self._build_cam_model_stack(cam_models[0])
        raw_rgb = torch.from_numpy(curr_rgb)
        # Move the last axis of the normal map to the front.
        curr_normal = torch.from_numpy(curr_normal.transpose((2, 0, 1)))

        return dict(
            input=rgbs[0],
            target=depth_out,
            intrinsic=curr_intrinsic_mat,
            filename=filename,
            dataset=self.data_name,
            cam_model=cam_models_stacks,
            pad=pad,
            scale=scale_ratio,
            raw_rgb=raw_rgb,
            sample_id=idx,
            data_path=meta_data['rgb'],
            inv_depth=inv_depth,
            normal=curr_normal,
        )

    def get_data_for_trainval(self, idx: int):
        """Build one training / validation sample.

        Args:
            idx: Index into ``self.annotations['files']``.

        Returns:
            dict: Transformed input image, normalised depth target,
            intrinsic matrix, camera-model stack, semantic mask, stereo
            depth, normals, inverse depth and the label scale factor.
        """
        anno = self.annotations['files'][idx]
        meta_data = self.load_meta_data(anno)
        data_path = self.load_data_path(meta_data)
        data_batch = self.load_batch(meta_data, data_path, test_mode=False)

        curr_rgb = data_batch['curr_rgb']
        curr_depth = data_batch['curr_depth']
        curr_normal = data_batch['curr_normal']
        curr_sem = data_batch['curr_sem']
        curr_cam_model = data_batch['curr_cam_model']

        if 'curr_stereo_depth' in data_batch:
            curr_stereo_depth = data_batch['curr_stereo_depth']
        else:
            # Fall back to whatever the loader yields for a missing path.
            curr_stereo_depth = self.load_stereo_depth_label(
                None, H=curr_rgb.shape[0], W=curr_rgb.shape[1])

        curr_intrinsic = meta_data['cam_in']

        transform_paras = dict(random_crop_size=self.random_crop_size)
        # All label maps must share the RGB image's spatial size.
        assert curr_rgb.shape[:2] == curr_depth.shape == curr_normal.shape[:2] == curr_sem.shape
        rgbs, depths, intrinsics, cam_models, normals, other_labels, transform_paras = self.img_transforms(
            images=[curr_rgb],
            labels=[curr_depth],
            intrinsics=[curr_intrinsic],
            cam_models=[curr_cam_model],
            normals=[curr_normal],
            other_labels=[curr_sem, curr_stereo_depth],
            transform_paras=transform_paras)

        sem_mask = other_labels[0].int()
        excluded = sem_mask == 142
        # NOTE(review): 142 is presumably the semantic id of a class whose
        # depth must be ignored (e.g. sky) — confirm against the label map.

        depth_out = self.normalize_depth(depths[0])
        depth_out[excluded] = -1

        inv_depth = self.depth2invdepth(depth_out, excluded)
        filename = os.path.basename(meta_data['rgb'])[:-4] + '.jpg'
        curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])
        cam_models_stacks = self._build_cam_model_stack(cam_models[0])

        # Zero out stereo depth outside the open interval (0.3, 200), then
        # apply the same label scaling the transforms reported.
        raw_stereo = other_labels[1]
        stereo_depth_pre_trans = raw_stereo * (raw_stereo > 0.3) * (raw_stereo < 200)
        stereo_depth = stereo_depth_pre_trans * transform_paras['label_scale_factor']
        stereo_depth = self.normalize_depth(stereo_depth)

        pad = transform_paras.get('pad', [0, 0, 0, 0])
        return dict(
            input=rgbs[0],
            target=depth_out,
            intrinsic=curr_intrinsic_mat,
            filename=filename,
            dataset=self.data_name,
            cam_model=cam_models_stacks,
            pad=torch.tensor(pad),
            data_type=[self.data_type],
            sem_mask=sem_mask,
            stereo_depth=stereo_depth,
            normal=normals[0],
            inv_depth=inv_depth,
            scale=transform_paras['label_scale_factor'],
        )

    def load_batch(self, meta_data, data_path, test_mode):
        """Load every per-frame array needed to build a sample.

        Args:
            meta_data: Mapping providing at least ``'cam_in'``.
            data_path: Mapping providing ``'rgb_path'``, ``'depth_path'``,
                ``'sem_path'``, ``'normal_path'``, ``'depth_mask_path'`` and
                ``'disp_path'``.
            test_mode: Forwarded to the RGB/depth and normal loaders.

        Returns:
            dict: ``curr_rgb``, ``curr_depth``, ``curr_sem``,
            ``curr_normal``, ``curr_cam_model`` and ``curr_stereo_depth``.
        """
        curr_intrinsic = meta_data['cam_in']

        curr_rgb, curr_depth = self.load_rgb_depth(
            data_path['rgb_path'], data_path['depth_path'], test_mode)
        img_h, img_w = curr_rgb.shape[0], curr_rgb.shape[1]

        curr_sem = self.load_sem_label(data_path['sem_path'], curr_depth)
        curr_cam_model = self.create_cam_model(img_h, img_w, curr_intrinsic)
        curr_normal = self.load_norm_label(
            data_path['normal_path'], H=img_h, W=img_w, test_mode=test_mode)

        # Pixels outside the validity mask are flagged with -1.
        depth_mask = self.load_depth_valid_mask(data_path['depth_mask_path'])
        curr_depth[~depth_mask] = -1

        curr_stereo_depth = self.load_stereo_depth_label(
            data_path['disp_path'], H=img_h, W=img_w)

        return dict(
            curr_rgb=curr_rgb,
            curr_depth=curr_depth,
            curr_sem=curr_sem,
            curr_normal=curr_normal,
            curr_cam_model=curr_cam_model,
            curr_stereo_depth=curr_stereo_depth,
        )

    def load_rgb_depth(self, rgb_path: str, depth_path: str, test_mode: bool):
        """Load the RGB image and its depth map from the given paths.

        A failed load (``load_data`` returning ``None``) is logged; the
        subsequent processing of the ``None`` value will then raise.

        Args:
            rgb_path: Path of the RGB image.
            depth_path: Path of the depth map.
            test_mode: Forwarded to ``process_depth``.

        Returns:
            tuple: ``(rgb, depth)`` where ``depth`` is a float64 array
            already passed through ``process_depth``.
        """
        rgb = self.load_data(rgb_path, is_rgb_img=True)
        if rgb is None:
            self.logger.info(f'>>>>{rgb_path} has errors.')

        depth = self.load_data(depth_path)
        if depth is None:
            self.logger.info(f'{depth_path} has errors.')

        # ``np.float`` was an alias of the builtin float (float64) and has
        # been removed from NumPy; name the dtype explicitly.
        depth = depth.astype(np.float64)
        depth = self.process_depth(depth, rgb, test_mode)
        return rgb, depth

    def process_depth(self, depth, rgb, test_mode=False):
        """Clean raw depth values and rescale them by ``metric_scale``.

        ``depth`` is modified in place before the optional resize.

        Args:
            depth: Float depth array.
            rgb: Image with shape ``(h, w, channels)``; supplies the target
                size for the resize.
            test_mode: When falsy, depth is resized (nearest neighbour) to
                the RGB resolution.

        Returns:
            The processed depth array.
        """
        depth[depth > 65500] = 0  # values above 65500 are zeroed
        depth /= self.metric_scale
        h, w, _ = rgb.shape
        if not test_mode:
            depth = cv2.resize(depth, (w, h), interpolation=cv2.INTER_NEAREST)
        return depth

    def load_norm_label(self, norm_path, H, W, test_mode):
        """Load a surface-normal map, zeroing invalid pixels.

        Args:
            norm_path: Path of the normal image, or ``None`` when there is
                no normal label.
            H: Target height.
            W: Target width.
            test_mode: When falsy, the result is resized (nearest
                neighbour) to ``(H, W)``.

        Returns:
            np.ndarray: float32 array. All zeros of shape ``(H, W, 3)``
            when ``norm_path`` is ``None``; otherwise the decoded normals
            with invalid pixels set to zero.
        """
        if norm_path is None:
            return np.zeros((H, W, 3)).astype(np.float32)

        norm_gt = cv2.imread(norm_path)
        norm_gt = cv2.cvtColor(norm_gt, cv2.COLOR_BGR2RGB)
        norm_gt = np.array(norm_gt).astype(np.uint8)

        # The mask path is the normal path with its last 'normal' replaced
        # by 'orient-mask'.
        mask_path = 'orient-mask'.join(norm_path.rsplit('normal', 1))
        mask_gt = np.array(cv2.imread(mask_path)).astype(np.uint8)
        # A pixel is invalid only when all three mask channels are zero.
        all_zero = np.logical_and(
            np.logical_and(mask_gt[:, :, 0] == 0, mask_gt[:, :, 1] == 0),
            mask_gt[:, :, 2] == 0)
        valid_mask = np.logical_not(all_zero)[:, :, np.newaxis]

        # Map 8-bit values from [0, 255] to [-1, 1].
        norm_gt = ((norm_gt.astype(np.float32) / 255.0) * 2.0) - 1.0
        # Additionally drop vectors whose length is not above 0.5.
        norm_valid_mask = (np.linalg.norm(norm_gt, axis=2, keepdims=True) > 0.5) * valid_mask
        norm_gt = norm_gt * norm_valid_mask

        if not test_mode:
            norm_gt = cv2.resize(norm_gt, (W, H), interpolation=cv2.INTER_NEAREST)

        return norm_gt
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: build the dataset from a config file and print it.
    from mmcv.utils import Config

    # NOTE(review): the config path and the 'Apolloscape' key look copied
    # from another dataset's script — confirm the ScanNet equivalents.
    cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
    # The class defined in this module is ScanNetDataset; the previously
    # referenced NYUDataset is not defined or imported here.
    dataset_i = ScanNetDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
    print(dataset_i)
|
|