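"""AnyDataset: a test-phase dataset built on BaseDataset. It loads RGB/depth
pairs, camera intrinsics, and camera models from an mldb-style annotation
file and applies the transform pipeline configured in cfg.data[phase]."""
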
import os
import random
import copy

import torch

from .__base_dataset__ import BaseDataset
import mono.utils.transform as img_transform


class AnyDataset(BaseDataset):
    def __init__(self, cfg, phase, **kwargs):        
        super(AnyDataset, self).__init__(
            cfg=cfg,
            phase=phase,
            **kwargs)

        self.cfg = cfg
        self.phase = phase
        self.mldb_info = kwargs['mldb_info']
        
        # root dir for data
        self.data_root = os.path.join(self.mldb_info['mldb_root'], self.mldb_info['data_root'])
        # depth/disp data root
        disp_root = self.mldb_info.get('disp_root')
        self.disp_root = os.path.join(self.mldb_info['mldb_root'], disp_root) \
            if disp_root is not None else None
        depth_root = self.mldb_info.get('depth_root')
        self.depth_root = os.path.join(self.mldb_info['mldb_root'], depth_root) \
            if depth_root is not None else self.data_root
        # meta data root
        meta_data_root = self.mldb_info.get('meta_data_root')
        self.meta_data_root = os.path.join(self.mldb_info['mldb_root'], meta_data_root) \
            if meta_data_root is not None else None
        # semantic segmentation labels root
        sem_root = self.mldb_info.get('semantic_root')
        self.sem_root = os.path.join(self.mldb_info['mldb_root'], sem_root) \
            if sem_root is not None else None
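        # Expected mldb_info layout (hypothetical values, for illustration only):
        #   {'mldb_root': '/path/to/mldb', 'data_root': 'NuScenes/rgb',
        #    'depth_root': 'NuScenes/depth', 'meta_data_root': 'NuScenes/meta',
        #    'semantic_root': 'NuScenes/sem'}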

        # data annotations path (fill this in for your own dataset)
        self.data_annos_path = '/yvan1/data/NuScenes/NuScenes/annotations/train_ring_annotations.json'

        # load annotations
        annotations = self.load_annotations()
        whole_data_size = len(annotations['files']) 
               
        # sample_size == -1 means "take sample_ratio of the whole set";
        # otherwise take at most cfg_sample_size samples.
        cfg_sample_ratio = cfg.data[phase].sample_ratio
        cfg_sample_size = int(cfg.data[phase].sample_size)
        if cfg_sample_size == -1:
            self.sample_size = int(whole_data_size * cfg_sample_ratio)
        else:
            self.sample_size = min(cfg_sample_size, whole_data_size)
        self.data_size = self.sample_size
        # randomly sample self.sample_size annotations from the whole set
        sample_list_of_whole_data = random.sample(range(whole_data_size), self.sample_size)
        self.annotations = {'files': [annotations['files'][i] for i in sample_list_of_whole_data]}
        self.sample_list = list(range(self.data_size))
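        # Worked example (hypothetical numbers): with whole_data_size=1000,
        # sample_ratio=0.1 and sample_size=-1 keep 100 random samples, while
        # sample_size=50 keeps a random 50 regardless of the ratio.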
        
        # config transforms for the input and label
        self.transforms_cfg = cfg.data[phase]['pipeline']
        self.transforms_lib = 'mono.utils.transform.'

        self.img_file_type = ['.png', '.jpg', '.jpeg', '.bmp', '.tif']
        self.np_file_type = ['.npz', '.npy']

        # update canonical space information
        self.data_basic = copy.deepcopy(kwargs)
        canonical = self.data_basic.pop('canonical_space')
        self.data_basic.update(canonical)
        self.depth_range = kwargs['depth_range'] # predefined depth range for the network
        self.clip_depth_range = kwargs['clip_depth_range'] # predefined depth range for data processing
        self.depth_normalize = kwargs['depth_normalize']
        
        self.img_transforms = img_transform.Compose(self.build_data_transforms())
        self.EPS = 1e-8

        self.tmpl_info = ['rgb_sr', 'rgb_pre', 'rgb_next']

        # dataset info
        self.data_name = cfg.data_name
        self.data_type = cfg.data_type # one of four main types: 'rel', 'sfm', 'stereo', 'lidar'

    def __getitem__(self, idx: int) -> dict:
        return self.get_data_for_test(idx)

    def get_data_for_test(self, idx: int):
        # basic info
        anno = self.annotations['files'][idx]
        curr_rgb_path = os.path.join(self.data_root, anno['CAM_FRONT_RIGHT']['rgb']) # Lyft: CAM_FRONT_LEFT
        curr_depth_path = os.path.join(self.depth_root, anno['CAM_FRONT_RIGHT']['depth'])
        meta_data = self.load_meta_data(anno['CAM_FRONT_RIGHT'])
        ori_curr_intrinsic = meta_data['cam_in']
        
        curr_rgb, curr_depth = self.load_rgb_depth(curr_rgb_path, curr_depth_path)
        ori_h, ori_w, _ = curr_rgb.shape
        # create camera model
        curr_cam_model = self.create_cam_model(curr_rgb.shape[0], curr_rgb.shape[1], ori_curr_intrinsic)
        # load tmpl rgb info
        # tmpl_annos = self.load_tmpl_annos(anno, curr_rgb, meta_data)
        # tmpl_rgb = tmpl_annos['tmpl_rgb_list'] # list of reference rgbs

        transform_paras = dict()
        rgbs, depths, intrinsics, cam_models, other_labels, transform_paras = self.img_transforms(
                                                                   images=[curr_rgb, ], 
                                                                   labels=[curr_depth, ], 
                                                                   intrinsics=[ori_curr_intrinsic,], 
                                                                   cam_models=[curr_cam_model, ],
                                                                   transform_paras=transform_paras)
        # alternative: clipped depth at the augmented size
        # depth_out = self.clip_depth(depths[0])
        # alternative: clipped depth at the original size
        # depth_out = self.clip_depth(curr_depth)
        # here we keep the unclipped depth at the original size
        depth_out = curr_depth

        filename = os.path.basename(curr_rgb_path)
        curr_intrinsic_mat = self.intrinsics_list2mat(intrinsics[0])

        pad = transform_paras.get('pad', [0, 0, 0, 0])
        scale_ratio = transform_paras.get('label_scale_factor', 1.0)
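        # Build a pyramid of the camera-model tensor at 1/2 ... 1/32 of the
        # augmented resolution (presumably matched to decoder feature scales).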
        cam_models_stacks = [
            torch.nn.functional.interpolate(
                cam_models[0][None, :, :, :],
                size=(cam_models[0].shape[1] // i, cam_models[0].shape[2] // i),
                mode='bilinear',
                align_corners=False,
            ).squeeze()
            for i in [2, 4, 8, 16, 32]
        ]
        raw_rgb = torch.from_numpy(curr_rgb)
        data = dict(input=rgbs[0],
                    target=depth_out,
                    intrinsic=curr_intrinsic_mat,
                    filename=filename,
                    dataset=self.data_name,
                    cam_model=cam_models_stacks,
                    # ref_input=rgbs[1:],
                    # tmpl_flg=tmpl_annos['w_tmpl'],
                    pad=pad,
                    scale=scale_ratio,
                    raw_rgb=raw_rgb) 
        return data


    def process_depth(self, depth):
        depth[depth > 65500] = 0  # zero out invalid pixels (near the uint16 max)
        depth /= 200.0  # convert raw values to metric depth (dataset scale factor)
        return depth
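
    # Hedged sanity check (not part of the original file): process_depth
    # assumes raw depth maps where values above 65500 mark invalid pixels and
    # the metric value equals raw / 200.0, e.g.:
    #   raw = np.array([[400.0, 65535.0], [2000.0, 0.0]])
    #   self.process_depth(raw)  # -> [[2.0, 0.0], [10.0, 0.0]]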



if __name__ == '__main__':
    from mmcv.utils import Config
    cfg = Config.fromfile('mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py')
    dataset_i = AnyDataset(cfg['Apolloscape'], 'train', **cfg.data_basic)
    print(dataset_i)
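
    # Hedged usage sketch (assumes BaseDataset implements __len__ and the
    # annotation path above exists): iterate with a standard DataLoader.
    # batch_size=1 keeps default collation simple for the list-valued
    # 'cam_model' entry.
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset_i, batch_size=1, shuffle=False, num_workers=0)
    sample = next(iter(loader))
    print(sample['filename'], sample['input'].shape)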