Spaces:
Building
Building
import os | |
import torch | |
import numpy as np | |
from pyquaternion import Quaternion | |
from nuscenes.nuscenes import NuScenes | |
from itertools import chain | |
from PIL import Image | |
from torchvision import transforms as T | |
import torchvision.transforms as tvf | |
from torchvision.transforms.functional import to_tensor | |
from .splits_roddick import create_splits_scenes_roddick | |
from ..image import pad_image, rectify_image, resize_image | |
from .utils import decode_binary_labels | |
from ..utils import decompose_rotmat | |
from ...utils.io import read_image | |
from ...utils.wrappers import Camera | |
from ..schema import NuScenesDataConfiguration | |
class NuScenesDataset(torch.utils.data.Dataset):
    """nuScenes ``CAM_FRONT`` frames paired with pre-rendered label images.

    At construction time every sample whose scene belongs to the requested
    split is indexed: its image path, camera model, sensor-to-global pose and
    the roll/pitch/yaw decomposed from that pose are cached in ``self.data``.
    ``__getitem__`` then loads the image and the label PNG named after the
    sample-data token from ``cfg.map_dir``.
    """

    def __init__(self, cfg: "NuScenesDataConfiguration", split="train"):
        """Index the ``CAM_FRONT`` frames of ``split``.

        Args:
            cfg: Dataset configuration. Fields read here: ``version``,
                ``data_dir``, ``map_dir``, ``percentage`` and (for the train
                split) ``augmentations``.
            split: ``"train"``, ``"val"`` or ``"test"``. ``"test"`` reuses the
                validation scenes for ``v1.0-trainval``.

        Raises:
            KeyError: If ``cfg.version`` / ``split`` is not in the mapping
                below (e.g. ``"test"`` with ``v1.0-mini``).
        """
        self.cfg = cfg
        self.nusc = NuScenes(version=cfg.version, dataroot=str(cfg.data_dir))
        self.map_data_root = cfg.map_dir
        self.split = split

        self.scenes = create_splits_scenes_roddick()  # custom based on Roddick et al.
        scene_split = {
            'v1.0-trainval': {'train': 'train', 'val': 'val', 'test': 'val'},
            'v1.0-mini': {'train': 'mini_train', 'val': 'mini_val'},
        }[cfg.version][split]
        self.scenes = self.scenes[scene_split]

        # Set membership keeps the per-sample scene test O(1).
        scene_names = set(self.scenes)
        self.sample = [
            sample for sample in self.nusc.sample
            if self.nusc.get('scene', sample['scene_token'])['name'] in scene_names
        ]

        # Augmentations are only applied while training.
        self.tfs = self.get_augmentations() if split == "train" else T.Compose([])

        self.data = []
        for sample in self.sample:
            if "CAM_FRONT" not in sample['data']:
                continue
            d = self.nusc.get('sample_data', sample['data']["CAM_FRONT"])

            scene = self.nusc.get('scene', sample['scene_token'])
            location = self.nusc.get('log', scene['log_token'])['location']
            ego_pose = self.nusc.get('ego_pose', d['ego_pose_token'])
            calibrated_sensor = self.nusc.get(
                "calibrated_sensor", d['calibrated_sensor_token'])

            ego2global = self._pose_matrix(
                ego_pose['rotation'], ego_pose['translation'])
            sensor2ego = self._pose_matrix(
                calibrated_sensor['rotation'], calibrated_sensor['translation'])
            sensor2global = ego2global @ sensor2ego
            roll, pitch, yaw = decompose_rotmat(sensor2global[:3, :3])

            intrinsic = calibrated_sensor['camera_intrinsic']
            fx, fy = intrinsic[0][0], intrinsic[1][1]
            cx, cy = intrinsic[0][2], intrinsic[1][2]
            # NOTE(review): the 0.5 shift of the principal point presumably
            # converts between pixel-centre conventions - confirm with Camera.
            cam = Camera(torch.tensor(
                [d['width'], d['height'], fx, fy, cx - 0.5, cy - 0.5])).float()

            self.data.append({
                'filename': d['filename'],
                'yaw': yaw,
                'pitch': pitch,
                'roll': roll,
                'cam': cam,
                'sensor2global': sensor2global,
                'token': d['token'],
                'sample_token': d['sample_token'],
                'location': location,
            })

        # Optionally train on only the leading fraction of the indexed frames.
        if self.cfg.percentage < 1.0 and split == "train":
            self.data = self.data[:int(len(self.data) * self.cfg.percentage)]

    @staticmethod
    def _pose_matrix(rotation, translation):
        """Build a 4x4 float32 homogeneous transform from a quaternion and a translation."""
        matrix = np.eye(4, dtype=np.float32)
        matrix[:3, :3] = Quaternion(rotation).rotation_matrix
        matrix[:3, 3] = translation
        return matrix

    def _random_resized_crop(self, image):
        """Randomly crop 80-100% of ``image`` and resize back to its own size.

        ``RandomResizedCrop`` requires an output ``size``; it is taken from the
        incoming image so the transform never changes the tensor's height/width.
        """
        height, width = int(image.shape[-2]), int(image.shape[-1])
        return tvf.RandomResizedCrop(size=(height, width), scale=(0.8, 1.0))(image)

    def _add_gaussian_noise(self, image):
        """Add Gaussian noise configured by ``cfg.augmentations.gaussian_noise``.

        The result is clamped to ``[0, 1]``, the range ``__getitem__`` produces
        for images before augmentation.
        """
        noise_cfg = self.cfg.augmentations.gaussian_noise
        noise = torch.randn_like(image) * noise_cfg.std + noise_cfg.mean
        return (image + noise).clamp(0.0, 1.0)

    @staticmethod
    def _normalize_confidence(visibility_mask):
        """Min-max normalise a visibility mask into a float confidence map.

        A constant mask has zero range; it is returned as-is (as float) instead
        of being divided by zero, which would yield an all-NaN map.
        """
        confidence = visibility_mask.clone().float()
        low = confidence.min()
        span = confidence.max() - low
        if span > 0:
            return (confidence - low) / span
        return confidence

    def get_augmentations(self):
        """Compose the training-time image augmentations selected in ``cfg.augmentations``.

        Returns:
            A ``Compose`` that always starts with a colour jitter and optionally
            adds a size-preserving random resized crop, Gaussian noise and a
            second brightness/contrast-only jitter.
        """
        print("Augmentation!", "\n" * 10)
        aug_cfg = self.cfg.augmentations
        augmentations = [
            tvf.ColorJitter(
                brightness=aug_cfg.brightness,
                contrast=aug_cfg.contrast,
                saturation=aug_cfg.saturation,
                hue=aug_cfg.hue,
            )
        ]
        if aug_cfg.random_resized_crop:
            # NOTE(review): the crop is applied to the image only; the camera
            # and valid mask returned by __getitem__ are not adjusted.
            augmentations.append(tvf.Lambda(self._random_resized_crop))
        if aug_cfg.gaussian_noise.enabled:
            # torchvision.transforms (v1) has no GaussianNoise transform, so
            # the noise is applied by a local helper.
            augmentations.append(tvf.Lambda(self._add_gaussian_noise))
        if aug_cfg.brightness_contrast.enabled:
            augmentations.append(
                tvf.ColorJitter(
                    brightness=aug_cfg.brightness_contrast.brightness_factor,
                    contrast=aug_cfg.brightness_contrast.contrast_factor,
                    saturation=0,  # Keep saturation at 0 for brightness and contrast adjustment
                    hue=0,
                )
            )
        return tvf.Compose(augmentations)

    def __len__(self):
        """Return the number of indexed frames."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the frame at ``idx`` together with its label masks.

        Returns:
            A dict with keys ``image``, ``roll_pitch_yaw``, ``camera``,
            ``valid``, ``seg_masks``, ``token``, ``sample_token``,
            ``location``, ``flood_masks``, ``confidence_map`` and ``name``.
        """
        d = self.data[idx]
        image = np.array(read_image(os.path.join(self.nusc.dataroot, d['filename'])))
        cam = d['cam']
        roll = d['roll']
        pitch = d['pitch']
        yaw = d['yaw']

        with Image.open(self.map_data_root / f"{d['token']}.png") as semantic_image:
            semantic_mask = to_tensor(semantic_image)

        # One extra decoded channel: the last one is used as the visibility mask.
        semantic_mask = decode_binary_labels(semantic_mask, self.cfg.num_classes + 1)
        semantic_mask = torch.nn.functional.max_pool2d(
            semantic_mask.float(), (2, 2), stride=2)  # 2 times downsample
        semantic_mask = semantic_mask.permute(1, 2, 0)
        semantic_mask = torch.flip(semantic_mask, [0])

        visibility_mask = semantic_mask[..., -1]
        semantic_mask = semantic_mask[..., :-1]
        if self.cfg.class_mapping is not None:
            semantic_mask = semantic_mask[..., self.cfg.class_mapping]

        # HWC uint8 array -> CHW float tensor in [0, 1].
        image = (
            torch.from_numpy(np.ascontiguousarray(image))
            .permute(2, 0, 1)
            .float()
            .div_(255)
        )

        if not self.cfg.gravity_align:
            # Turn off gravity alignment
            roll = 0.0
            pitch = 0.0
            image, valid = rectify_image(image, cam, roll, pitch)
        else:
            image, valid = rectify_image(
                image, cam, roll, pitch if self.cfg.rectify_pitch else None
            )
            # Rectified angles are reported as zero afterwards.
            roll = 0.0
            if self.cfg.rectify_pitch:
                pitch = 0.0

        if self.cfg.resize_image is not None:
            image, _, cam, valid = resize_image(
                image, self.cfg.resize_image, fn=max, camera=cam, valid=valid
            )
            if self.cfg.pad_to_square:
                image, valid, cam = pad_image(image, self.cfg.resize_image, cam, valid)

        image = self.tfs(image)
        confidence_map = self._normalize_confidence(visibility_mask)

        return {
            "image": image,
            "roll_pitch_yaw": torch.tensor([roll, pitch, yaw]).float(),
            "camera": cam,
            "valid": valid,
            "seg_masks": semantic_mask.float(),
            "token": d['token'],
            "sample_token": d['sample_token'],
            'location': d['location'],
            'flood_masks': visibility_mask.float(),
            "confidence_map": confidence_map,
            'name': d['sample_token'],
        }