import numpy as np
import torch
from torchvision.transforms import functional as tfn
import torchvision.transforms.functional as tvf
from ..utils import decompose_rotmat
from ..image import pad_image, rectify_image, resize_image
from ...utils.wrappers import Camera
from ..schema import KITTIDataConfiguration
class BEVTransform:
    """Turn a raw KITTI sample (front-view image + BEV panoptic labels) into
    the tensor dictionary consumed by the model.

    Pipeline: gravity-rectify the image using the ego pose, optionally resize
    to a canonical focal length and/or pad, then rotate + center-crop the BEV
    instance mask and expand it into per-class one-hot masks down-sampled to
    ``SEG_MASK_SIZE``.
    """

    # Side length (px) of the square region kept from the rotated BEV mask.
    BEV_CROP_SIZE = 672
    # Spatial size (H, W) of the down-sampled semantic masks returned.
    SEG_MASK_SIZE = (100, 100)

    def __init__(self, cfg: "KITTIDataConfiguration", augmentations):
        # cfg: dataset configuration (gravity_align, resize_image, ...).
        self.cfg = cfg
        # Callable applied to the image tensor after geometric processing.
        self.augmentations = augmentations

    @staticmethod
    def _compact_labels(msk, cat, iscrowd):
        """Remap the sparse instance ids in ``msk`` to a dense 0..K range.

        ``cat`` and ``iscrowd`` are indexed per original instance id; they are
        re-gathered so that entry ``k`` describes compacted id ``k``.  Id 0
        (background) is always kept and always maps to 0.
        """
        ids = np.unique(msk)
        # Ensure background id 0 is present so it occupies compact slot 0.
        if 0 not in ids:
            ids = np.concatenate((np.array([0], dtype=np.int32), ids), axis=0)
        ids_to_compact = np.zeros((ids.max() + 1,), dtype=np.int32)
        ids_to_compact[ids] = np.arange(0, ids.size, dtype=np.int32)
        msk = ids_to_compact[msk]
        cat = cat[ids]
        iscrowd = iscrowd[ids]
        return msk, cat, iscrowd

    def _rectify_and_resize(self, img, cam, roll, pitch):
        """Gravity-align ``img`` and bring it to the configured output size.

        Returns ``(img, valid, cam, roll, pitch)`` where ``valid`` is the
        validity mask produced by rectification/resizing and roll/pitch are
        zeroed for the components that were rectified away.
        """
        if not self.cfg.gravity_align:
            # Gravity alignment disabled: rectify with zero angles so the
            # image still goes through the same resampling path.
            roll = 0.0
            pitch = 0.0
            img, valid = rectify_image(img, cam, roll, pitch)
        else:
            img, valid = rectify_image(
                img, cam, roll, pitch if self.cfg.rectify_pitch else None
            )
            roll = 0.0
            if self.cfg.rectify_pitch:
                pitch = 0.0

        if self.cfg.target_focal_length is not None:
            # Resize so the focal length matches the canonical value.
            factor = self.cfg.target_focal_length / cam.f.numpy()
            size = (np.array(img.shape[-2:][::-1]) * factor).astype(int)
            img, _, cam, valid = resize_image(img, size, camera=cam, valid=valid)
            size_out = self.cfg.resize_image
            if size_out is None:
                # Round both edges up to a multiple of the configured stride.
                stride = self.cfg.pad_to_multiple
                size_out = (np.ceil((size / stride)) * stride).astype(int)
            # Crop or pad such that both edges are of the given size.
            img, valid, cam = pad_image(
                img, size_out, cam, valid, crop_and_center=False
            )
        elif self.cfg.resize_image is not None:
            img, _, cam, valid = resize_image(
                img, self.cfg.resize_image, fn=max, camera=cam, valid=valid
            )
            if self.cfg.pad_to_square:
                # Pad such that both edges are of the given size.
                img, valid, cam = pad_image(img, self.cfg.resize_image, cam, valid)
        return img, valid, cam, roll, pitch

    def _bev_to_seg_masks(self, bev_msk, bev_cat, bev_iscrowd):
        """Rotate/crop the BEV instance mask and build one-hot class masks.

        Returns ``(seg_masks_down, flood_masks)``: H×W×C one-hot class masks
        at ``SEG_MASK_SIZE`` and a float mask of pixels with no active class.
        """
        bev_msk = np.expand_dims(
            np.array(bev_msk, dtype=np.int32, copy=False), axis=0
        )
        bev_msk, bev_cat, _ = self._compact_labels(bev_msk, bev_cat, bev_iscrowd)
        bev_msk = torch.from_numpy(bev_msk)
        bev_cat = torch.from_numpy(bev_cat)

        # Rotate the BEV map, then take the top rows and a horizontally
        # centered crop of BEV_CROP_SIZE pixels.
        rotated = torch.rot90(bev_msk, dims=(1, 2))
        crop = self.BEV_CROP_SIZE
        off = (rotated.size(2) - crop) // 2
        # off:off+crop matches the former off:-off slice for every width but
        # also works when the rotated width equals the crop size (off == 0,
        # where 0:-0 produced an empty tensor).
        bev_msk = rotated[:, :crop, off:off + crop].squeeze(0)

        # Per-pixel semantic class via the instance -> class lookup table.
        seg_masks = bev_cat[bev_msk]
        onehot_src = seg_masks.clone()
        # 255 marks "ignore"; temporarily fold into class 0 so one_hot works,
        # then zero those rows out entirely below.
        onehot_src[onehot_src == 255] = 0
        onehot = torch.nn.functional.one_hot(
            onehot_src.to(torch.int64),
            num_classes=self.cfg.num_classes
        )
        onehot[seg_masks == 255] = 0
        onehot = onehot.permute(2, 0, 1)
        seg_masks_down = tvf.resize(onehot, list(self.SEG_MASK_SIZE))
        seg_masks_down = seg_masks_down.permute(1, 2, 0)

        if self.cfg.class_mapping is not None:
            # Reorder / subset the class channels per the configured mapping.
            seg_masks_down = seg_masks_down[:, :, self.cfg.class_mapping]

        # Pixels with no active class at all (unknown / unobserved region).
        flood_masks = torch.all(seg_masks_down == 0, dim=2).float()
        return seg_masks_down, flood_masks

    def __call__(self, img, bev_msk=None, bev_plabel=None, fv_msk=None, bev_weights_msk=None,
                 bev_cat=None, bev_iscrowd=None, fv_cat=None, fv_iscrowd=None,
                 fv_intrinsics=None, ego_pose=None):
        """Process one sample and return the training dictionary.

        ``bev_plabel``, ``fv_msk``, ``bev_weights_msk``, ``fv_cat`` and
        ``fv_iscrowd`` are accepted for interface compatibility but unused.

        Raises:
            ValueError: when a de-facto required input is missing.  (The
                previous implementation crashed later with an opaque
                NameError/TypeError on these paths.)
        """
        if ego_pose is None:
            raise ValueError("ego_pose is required")
        if fv_intrinsics is None:
            raise ValueError("fv_intrinsics is required")
        if bev_msk is None or bev_cat is None or bev_iscrowd is None:
            raise ValueError("bev_msk, bev_cat and bev_iscrowd are required")

        bev_cat = np.array(bev_cat, dtype=np.int32)
        bev_iscrowd = np.array(bev_iscrowd, dtype=np.uint8)
        ego_pose = np.array(ego_pose, dtype=np.float32)
        roll, pitch, yaw = decompose_rotmat(ego_pose[:3, :3])

        img = tfn.to_tensor(img)

        # Camera model from the front-view intrinsics; the -0.5 shifts the
        # principal point to the pixel-center convention used by Camera.
        fx, fy = fv_intrinsics[0][0], fv_intrinsics[1][1]
        cx, cy = fv_intrinsics[0][2], fv_intrinsics[1][2]
        height, width = img.shape[1], img.shape[2]
        cam = Camera(torch.tensor(
            [width, height, fx, fy, cx - 0.5, cy - 0.5])).float()

        img, valid, cam, roll, pitch = self._rectify_and_resize(img, cam, roll, pitch)

        seg_masks_down, flood_masks = self._bev_to_seg_masks(
            bev_msk, bev_cat, bev_iscrowd
        )
        img = self.augmentations(img)

        ret = {
            "image": img,
            "valid": valid,
            "camera": cam,
            "seg_masks": seg_masks_down.float().contiguous(),
            "flood_masks": flood_masks,
            "roll_pitch_yaw": torch.tensor((roll, pitch, yaw)).float(),
            # NOTE(review): confidence_map aliases flood_masks in the original
            # code as well — presumably intentional; confirm with the caller.
            "confidence_map": flood_masks,
        }
        # Promote any remaining numpy arrays to torch tensors.
        for key, value in ret.items():
            if isinstance(value, np.ndarray):
                ret[key] = torch.from_numpy(value)
        return ret