File size: 3,682 Bytes
3ef1661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import json
import torch
import torchvision.transforms as transforms
import os.path
import numpy as np
import cv2
from torch.utils.data import Dataset
import random
from .__base_dataset__ import BaseDataset


class IBIMSDataset(BaseDataset):
    """Dataset wrapper for iBims samples that derives normals from depth.

    Loading of RGB, depth, semantic labels, the camera model and the depth
    validity mask is delegated to helpers called on ``self`` (presumably
    inherited from ``BaseDataset`` -- they are not defined in this class).
    Unlike those, the surface-normal label is not read from disk: it is
    estimated from the depth map by a local plane fit (see
    ``load_norm_label``).
    """

    # Side length (in pixels) of the square neighbourhood used for the local
    # plane fit. The pooling, unfolding and padding layers must all agree on
    # it, so it is defined once here. Must be odd so the window is centred.
    _NORMAL_WINDOW = 7

    def __init__(self, cfg, phase, **kwargs):
        """Initialise the base dataset and the layers used for normal estimation.

        Args:
            cfg: Dataset config; ``cfg.metric_scale`` is read here, the rest is
                forwarded to the base class.
            phase: Forwarded unchanged to the base class.
            **kwargs: Forwarded unchanged to the base class.
        """
        super().__init__(
            cfg=cfg,
            phase=phase,
            **kwargs)
        self.metric_scale = cfg.metric_scale

        window = self._NORMAL_WINDOW
        # Mean of every window (no padding -> output shrinks by window - 1).
        self.avg = torch.nn.AvgPool2d(kernel_size=window, stride=1, ceil_mode=False, count_include_pad=True, divisor_override=None)
        # Extracts every window as a column so all windows are processed at once.
        self.unfold = torch.nn.Unfold(kernel_size=window, dilation=1, padding=0, stride=1)
        # Restores the original resolution by zero-filling the lost border.
        self.pad = torch.nn.ZeroPad2d(window // 2)

    def process_depth(self, depth, rgb):
        """Zero out implausibly large raw depth values and rescale the map.

        The array is modified in place and also returned.

        Args:
            depth: Raw depth array (modified in place). The in-place division
                assumes a floating-point dtype -- TODO confirm against the loader.
            rgb: Unused here; kept for interface compatibility.

        Returns:
            ``depth`` with values above 50000 set to 0, divided by
            ``self.metric_scale``.
        """
        depth[depth > 50000] = 0
        depth /= self.metric_scale
        return depth

    def load_batch(self, meta_data, data_path):
        """Load every modality of one sample.

        Args:
            meta_data: Mapping providing the camera intrinsics under ``'cam_in'``.
            data_path: Mapping providing ``'rgb_path'``, ``'depth_path'``,
                ``'sem_path'``, ``'normal_path'`` and ``'depth_mask_path'``.

        Returns:
            dict with keys ``curr_rgb``, ``curr_depth``, ``curr_sem``,
            ``curr_normal`` and ``curr_cam_model``.
        """
        curr_intrinsic = meta_data['cam_in']
        # load rgb/depth
        curr_rgb, curr_depth = self.load_rgb_depth(data_path['rgb_path'], data_path['depth_path'])
        height, width = curr_rgb.shape[0], curr_rgb.shape[1]
        # get semantic labels
        curr_sem = self.load_sem_label(data_path['sem_path'], curr_depth)
        # create camera model
        curr_cam_model = self.create_cam_model(height, width, curr_intrinsic)
        # get normal labels -- differs from BaseDataset: normals are computed
        # from depth and intrinsics. This runs before the validity mask below
        # is written into the depth map.
        curr_normal = self.load_norm_label(data_path['normal_path'], H=height, W=width, depth=curr_depth, K=curr_intrinsic)
        # get depth mask; pixels outside the mask are flagged with -1
        depth_mask = self.load_depth_valid_mask(data_path['depth_mask_path'])
        curr_depth[~depth_mask] = -1
        data_batch = dict(
            curr_rgb=curr_rgb,
            curr_depth=curr_depth,
            curr_sem=curr_sem,
            curr_normal=curr_normal,
            curr_cam_model=curr_cam_model,
        )
        return data_batch

    def load_norm_label(self, norm_path, H, W, depth, K):
        """Estimate per-pixel surface normals from a depth map.

        Every pixel is back-projected to a 3D point using the inverse
        intrinsic matrix. For each ``_NORMAL_WINDOW`` x ``_NORMAL_WINDOW``
        neighbourhood the 3x3 scatter matrix of the mean-centred points is
        built, and the eigenvector of its smallest eigenvalue is taken as the
        normal. Normals are then flipped where needed so that their dot product
        with the 3D point is never positive.

        Args:
            norm_path: Unused; normals are computed, not read from disk.
            H: Image height in pixels.
            W: Image width in pixels.
            depth: Numpy depth map that squeezes to shape ``(H, W)``.
            K: Indexable intrinsics; ``K[0]``/``K[1]`` are placed on the matrix
                diagonal and ``K[2]``/``K[3]`` in its last column.

        Returns:
            Numpy array of shape ``(H, W, 3)``. A border of
            ``_NORMAL_WINDOW // 2`` pixels, where no full window exists, is zero.

        Raises:
            ValueError: If the image is smaller than the fitting window.
        """
        window = self._NORMAL_WINDOW
        if H < window or W < window:
            raise ValueError(
                f'Image of size {H}x{W} is smaller than the {window}x{window} '
                'window used for normal estimation.')
        # Spatial size after a stride-1, unpadded window pass.
        inner_h = H - (window - 1)
        inner_w = W - (window - 1)

        depth = torch.from_numpy(depth).squeeze()
        K = torch.Tensor([[K[0], 0, K[2]],
                          [0, K[1], K[3]],
                          [0, 0, 1]])
        K_inv = K.inverse()

        # Homogeneous pixel coordinates (x, y, 1) for every pixel, row-major.
        y, x = torch.meshgrid([torch.arange(0, H, dtype=torch.float32),
                               torch.arange(0, W, dtype=torch.float32)], indexing='ij')
        x = x.reshape(1, H * W)
        y = y.reshape(1, H * W)
        ones = torch.ones_like(x)
        coord_2d = torch.cat((x, y, ones), dim=0)

        # Back-project: rays through each pixel scaled by depth -> (1, 3, H, W).
        coord_3d = torch.matmul(K_inv, coord_2d).view(3, H, W)
        coord_3d = (coord_3d * depth[None, :])[None, :]
        coord_3d_mean = self.avg(coord_3d)

        # Treat x/y/z as a batch so each channel is unfolded independently:
        # (3, window*window, inner_h*inner_w).
        uf_coord_3d = self.unfold(coord_3d.permute(1, 0, 2, 3))
        coord_3d_decenter = uf_coord_3d - coord_3d_mean.view(3, 1, inner_h * inner_w)
        # One (3, window*window) matrix of centred points per window.
        coord_3d_decenter = coord_3d_decenter.permute(2, 0, 1)
        cov = torch.bmm(coord_3d_decenter, coord_3d_decenter.permute(0, 2, 1))

        # eigh returns eigenvalues in ascending order, so column 0 of the
        # eigenvector matrix belongs to the smallest one: the plane normal.
        eig = torch.linalg.eigh(cov)
        normal = (eig[1])[:, :, 0].float()
        normal = self.pad(normal.permute(1, 0).view(1, 3, inner_h, inner_w))

        # Keep normals whose dot product with the point is negative, negate
        # the others (eigenvectors have an arbitrary sign).
        orient_mask = (torch.sum(normal * coord_3d, dim=1) < 0).unsqueeze(1)
        normal = normal * orient_mask - normal * (~orient_mask)
        # Drop only the batch dimension so size-1 spatial dims survive.
        gt_normal = normal[0].permute(1, 2, 0).numpy()
        return gt_normal

if __name__ == '__main__':
    # Manual smoke test: build the dataset from a config file and print it.
    # NOTE(review): the config path and the 'Apolloscape' key look copied from
    # another dataset module -- confirm they are intended for IBIMSDataset.
    from mmcv.utils import Config

    config_file = 'mono/configs/Apolloscape_DDAD/convnext_base.cascade.1m.sgd.mae.py'
    cfg = Config.fromfile(config_file)
    dataset_cfg = cfg['Apolloscape']
    dataset = IBIMSDataset(dataset_cfg, 'train', **cfg.data_basic)
    print(dataset)