File size: 5,520 Bytes
4187c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import numpy as np
import torch
from torchvision.transforms import functional as tfn
import torchvision.transforms.functional as tvf

from ..utils import decompose_rotmat
from ..image import pad_image, rectify_image, resize_image
from ...utils.wrappers import Camera
from ..schema import KITTIDataConfiguration


class BEVTransform:
    """Preprocess one sample (front-view image + BEV mask) into tensors.

    The image is converted to a tensor, rectified with ``rectify_image``,
    resized/padded as configured in ``cfg`` and passed through
    ``augmentations``. The BEV instance mask is compacted, rotated by 90
    degrees, cropped, expanded into per-class one-hot masks and downsampled.
    """

    # Side length (pixels) of the crop taken from the rotated BEV mask.
    BEV_CROP_SIZE = 672
    # (height, width) the one-hot BEV masks are resized to.
    BEV_OUTPUT_SIZE = (100, 100)
    # Category value treated as "ignore": such pixels get an all-zero one-hot vector.
    IGNORE_LABEL = 255

    def __init__(self,
                 cfg: "KITTIDataConfiguration", augmentations):
        """
        Args:
            cfg: dataset configuration. This class reads ``gravity_align``,
                ``rectify_pitch``, ``target_focal_length``, ``resize_image``,
                ``pad_to_multiple``, ``pad_to_square``, ``num_classes`` and
                ``class_mapping`` from it.
            augmentations: callable applied to the image tensor after the
                geometric preprocessing.
        """
        self.cfg = cfg
        self.augmentations = augmentations

    @staticmethod
    def _compact_labels(msk, cat, iscrowd):
        """Renumber the ids present in ``msk`` to a dense range ``0..K-1``.

        Id 0 is always included in the id list, even if it does not occur in
        ``msk``. ``cat`` and ``iscrowd`` are lookup arrays indexed by the
        original ids. They are re-indexed so that entry ``k`` describes
        compact id ``k``. ``iscrowd`` may be ``None``, in which case ``None``
        is returned in its place.

        Returns:
            ``(msk, cat, iscrowd)`` with ``msk`` holding int32 compact ids.
        """
        ids = np.unique(msk)
        if 0 not in ids:
            ids = np.concatenate((np.array([0], dtype=np.int32), ids), axis=0)

        # Dense lookup table: original id -> position in the sorted id list.
        ids_to_compact = np.zeros((ids.max() + 1,), dtype=np.int32)
        ids_to_compact[ids] = np.arange(0, ids.size, dtype=np.int32)

        msk = ids_to_compact[msk]
        cat = cat[ids]
        iscrowd = iscrowd[ids] if iscrowd is not None else None

        return msk, cat, iscrowd

    @staticmethod
    def _center_crop(mask, size):
        """Keep the first ``size`` rows and a horizontally centered window of ``size`` columns.

        ``mask`` is indexed as ``(channels, rows, cols)``. An explicit stop
        index is used so that a zero margin (width == ``size``) yields the full
        width rather than an empty slice. A width smaller than ``size`` keeps
        all columns.
        """
        start = max((mask.size(2) - size) // 2, 0)
        return mask[:, :size, start:start + size]

    def _prepare_image(self, img, fv_intrinsics, roll, pitch):
        """Convert, rectify and resize/pad the front-view image.

        Returns:
            ``(img, valid, cam, roll, pitch)``, where ``roll``/``pitch`` are
            the angles that remain after rectification.
        """
        img = tfn.to_tensor(img)
        fx = fv_intrinsics[0][0]
        fy = fv_intrinsics[1][1]
        cx = fv_intrinsics[0][2]
        cy = fv_intrinsics[1][2]
        height, width = img.shape[1], img.shape[2]

        # NOTE(review): the -0.5 presumably converts the principal point to the
        # pixel-center convention expected by ``Camera`` — confirm.
        cam = Camera(torch.tensor(
            [width, height, fx, fy, cx - 0.5, cy - 0.5])).float()

        if not self.cfg.gravity_align:
            # Gravity alignment disabled: rectify with zero roll/pitch.
            roll = 0.0
            pitch = 0.0
            img, valid = rectify_image(img, cam, roll, pitch)
        else:
            img, valid = rectify_image(
                img, cam, roll, pitch if self.cfg.rectify_pitch else None
            )
            # Roll (and pitch, if rectified) has now been removed from the image.
            roll = 0.0
            if self.cfg.rectify_pitch:
                pitch = 0.0

        if self.cfg.target_focal_length is not None:
            # Resize to a canonical focal length.
            factor = self.cfg.target_focal_length / cam.f.numpy()
            size = (np.array(img.shape[-2:][::-1]) * factor).astype(int)
            img, _, cam, valid = resize_image(img, size, camera=cam, valid=valid)
            size_out = self.cfg.resize_image
            if size_out is None:
                # Round the edges up so that they are multiples of the stride.
                stride = self.cfg.pad_to_multiple
                size_out = (np.ceil((size / stride)) * stride).astype(int)
            # Crop or pad so that both edges have the given size.
            img, valid, cam = pad_image(
                img, size_out, cam, valid, crop_and_center=False
            )
        elif self.cfg.resize_image is not None:
            img, _, cam, valid = resize_image(
                img, self.cfg.resize_image, fn=max, camera=cam, valid=valid
            )
            if self.cfg.pad_to_square:
                # Pad so that both edges have the given size.
                img, valid, cam = pad_image(img, self.cfg.resize_image, cam, valid)

        return img, valid, cam, roll, pitch

    def _bev_semantic_masks(self, bev_msk, bev_cat, bev_iscrowd):
        """Build downsampled per-class one-hot masks from a BEV instance mask.

        Returns:
            Tensor laid out as ``(*BEV_OUTPUT_SIZE, classes)``. ``classes`` is
            ``cfg.num_classes``, or the length of ``cfg.class_mapping`` when
            that is set.
        """
        # np.asarray instead of np.array(copy=False): the latter raises under
        # NumPy >= 2.0 whenever a copy or dtype conversion is required.
        bev_msk = np.expand_dims(np.asarray(bev_msk, dtype=np.int32), axis=0)
        bev_msk, bev_cat, _ = self._compact_labels(bev_msk, bev_cat, bev_iscrowd)

        bev_msk = torch.from_numpy(bev_msk)
        bev_cat = torch.from_numpy(bev_cat)

        rotated_mask = torch.rot90(bev_msk, dims=(1, 2))
        bev_msk = self._center_crop(rotated_mask, self.BEV_CROP_SIZE).squeeze(0)

        # Map compact instance ids to category labels (long indices for portability).
        seg_masks = bev_cat[bev_msk.long()]

        ignore = seg_masks == self.IGNORE_LABEL
        seg_masks_onehot = torch.nn.functional.one_hot(
            seg_masks.masked_fill(ignore, 0).to(torch.int64),
            num_classes=self.cfg.num_classes,
        )
        # Ignore pixels get an all-zero vector instead of being counted as class 0.
        seg_masks_onehot[ignore] = 0

        # NOTE(review): this uses torchvision's default interpolation on an integer
        # one-hot tensor, so boundary pixels may end up all-zero or multi-hot.
        # Confirm this is intended (vs. NEAREST).
        seg_masks_down = tvf.resize(
            seg_masks_onehot.permute(2, 0, 1), self.BEV_OUTPUT_SIZE
        ).permute(1, 2, 0)

        if self.cfg.class_mapping is not None:
            seg_masks_down = seg_masks_down[:, :, self.cfg.class_mapping]
        return seg_masks_down

    def __call__(self, img, bev_msk=None, bev_plabel=None, fv_msk=None, bev_weights_msk=None,
                 bev_cat=None, bev_iscrowd=None, fv_cat=None, fv_iscrowd=None,
                 fv_intrinsics=None, ego_pose=None):
        """Transform one sample.

        Args:
            img: front-view image accepted by torchvision's ``to_tensor``.
            bev_msk: BEV instance-id mask. When ``None``, the semantic outputs
                (``seg_masks``, ``flood_masks``, ``confidence_map``) are omitted.
            bev_cat: per-instance-id category lookup; required with ``bev_msk``.
            bev_iscrowd: per-instance-id crowd flags (optional).
            fv_intrinsics: matrix-like intrinsics. Entries ``[0][0]``,
                ``[1][1]``, ``[0][2]`` and ``[1][2]`` are read as fx, fy, cx, cy.
            ego_pose: pose matrix whose top-left 3x3 block is decomposed into
                roll/pitch/yaw. When ``None``, all three angles are taken as 0.
            bev_plabel, fv_msk, bev_weights_msk, fv_cat, fv_iscrowd: accepted
                for interface compatibility but not used.

        Returns:
            dict with ``"image"``, ``"valid"``, ``"camera"`` and
            ``"roll_pitch_yaw"``. When ``bev_msk`` is given, it also contains
            ``"seg_masks"`` (float one-hot) and ``"flood_masks"`` /
            ``"confidence_map"`` (1.0 where no class is active). Any NumPy
            arrays are converted to tensors.
        """
        if bev_cat is not None:
            bev_cat = np.array(bev_cat, dtype=np.int32)
        if bev_iscrowd is not None:
            bev_iscrowd = np.array(bev_iscrowd, dtype=np.uint8)

        if ego_pose is not None:
            ego_pose = np.array(ego_pose, dtype=np.float32)
            roll, pitch, yaw = decompose_rotmat(ego_pose[:3, :3])
        else:
            roll, pitch, yaw = 0.0, 0.0, 0.0

        img, valid, cam, roll, pitch = self._prepare_image(
            img, fv_intrinsics, roll, pitch
        )

        seg_masks = None
        if bev_msk is not None:
            seg_masks = self._bev_semantic_masks(bev_msk, bev_cat, bev_iscrowd)

        img = self.augmentations(img)

        ret = {
            "image": img,
            "valid": valid,
            "camera": cam,
            "roll_pitch_yaw": torch.tensor((roll, pitch, yaw)).float(),
        }
        if seg_masks is not None:
            flood_masks = torch.all(seg_masks == 0, dim=2).float()
            ret["seg_masks"] = seg_masks.float().contiguous()
            ret["flood_masks"] = flood_masks
            ret["confidence_map"] = flood_masks

        return {
            key: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
            for key, value in ret.items()
        }