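# NOTE: the six lines below are page-header residue captured together with the
# file (author avatar caption, commit message, commit id, viewer links, size).
# They are kept as comments because they are not Python code.
#   jmercat's picture
#   Removed history to avoid any unverified information being released
#   5769ee4
#   raw
#   history blame
#   9.83 kB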
from typing import Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
class SceneDataLoaders:
    """Wrap pre-loaded scene trajectories into train/val/test DataLoaders.

    On construction, each dataset is separated into agent and ego trajectories,
    the agent trajectories are normalized by subtracting their initial state,
    and every trajectory is split into a past part and a future part. The
    ``*_dataloader`` methods then expose the result as ``TensorDataset``-backed
    loaders, padded with all-ones masks and zero-element map placeholders.

    Args:
        state_dim: dimension of the observed state (2 for x,y position observation)
        num_steps: number of observed steps
        num_steps_future: number of steps in the future
        batch_size: set data loader with this batch size
        data_train: training dataset; indexed as ``data_train[0]`` (agent
            trajectories) and ``data_train[1]`` (ego trajectories), so a pair of
            tensors or a tensor with a leading dimension of at least 2 works
        data_val: validation dataset, same layout as ``data_train``
        data_test: test dataset, same layout as ``data_train``
        num_workers: number of workers to use for data loading
    """

    def __init__(
        self,
        state_dim: int,
        num_steps: int,
        num_steps_future: int,
        batch_size: int,
        data_train: torch.Tensor,
        data_val: torch.Tensor,
        data_test: torch.Tensor,
        num_workers: int = 0,
    ):
        self._batch_size = batch_size
        self._num_workers = num_workers
        self._state_dim = state_dim
        self._num_steps = num_steps
        self._num_steps_future = num_steps_future
        # Must run last: split_trajectory reads the step counts and state_dim set above.
        self._setup_datasets(data_train, data_val, data_test)

    def train_dataloader(self, shuffle: bool = True, drop_last: bool = True) -> DataLoader:
        """Setup and return training DataLoader

        Each batch is a 10-tuple: (past, mask past, future, mask future,
        mask loss, map, mask map, offset, ego past, ego future).

        Args:
            shuffle: whether to reshuffle the samples at every epoch
            drop_last: whether to drop the last incomplete batch

        Returns:
            DataLoader: training DataLoader
        """
        data_size = self._data_train_past.shape[0]
        device = self._data_train_past.device
        # This is a didactic data loader that only defines minimalistic inputs.
        # This dataloader adds some empty tensors and ones to match the expected
        # format with masks and map information.
        train_loader = DataLoader(
            dataset=TensorDataset(
                self._data_train_past,
                torch.ones_like(self._data_train_past[..., 0]),  # Mask past
                self._data_train_fut,
                torch.ones_like(self._data_train_fut[..., 0]),  # Mask fut
                torch.ones_like(self._data_train_fut[..., 0]),  # Mask loss
                # Zero-element placeholders: zeros/ones are used (as in the
                # validation loader) instead of uninitialized memory; with no
                # elements the content is identical either way.
                torch.zeros(data_size, 1, 0, 0, device=device),  # Map
                torch.ones(data_size, 1, 0, device=device),  # Mask map
                self._offset_train,
                self._data_train_ego_past,
                self._data_train_ego_fut,
            ),
            batch_size=self._batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=self._num_workers,
        )
        return train_loader

    def val_dataloader(self, shuffle: bool = False, drop_last: bool = False) -> DataLoader:
        """Setup and return validation DataLoader

        Each batch is a 10-tuple: (past, mask past, future, mask future,
        mask loss, map, mask map, offset, ego past, ego future).

        Args:
            shuffle: whether to reshuffle the samples at every epoch
            drop_last: whether to drop the last incomplete batch

        Returns:
            DataLoader: validation DataLoader
        """
        data_size = self._data_val_past.shape[0]
        device = self._data_val_past.device
        # This is a didactic data loader that only defines minimalistic inputs.
        # This dataloader adds some empty tensors and ones to match the expected
        # format with masks and map information.
        val_loader = DataLoader(
            dataset=TensorDataset(
                self._data_val_past,
                torch.ones_like(self._data_val_past[..., 0]),  # Mask past
                self._data_val_fut,
                torch.ones_like(self._data_val_fut[..., 0]),  # Mask fut
                torch.ones_like(self._data_val_fut[..., 0]),  # Mask loss
                torch.zeros(data_size, 1, 0, 0, device=device),  # Map
                torch.ones(data_size, 1, 0, device=device),  # Mask map
                self._offset_val,
                self._data_val_ego_past,
                self._data_val_ego_fut,
            ),
            batch_size=self._batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=self._num_workers,
        )
        return val_loader

    def test_dataloader(self) -> DataLoader:
        """Setup and return test DataLoader

        Each batch is a 7-tuple: (past, mask past, map, mask map, offset,
        ego past, ego future). Unlike the training and validation loaders, the
        agent future and its masks are not part of the batch. Samples are never
        shuffled.

        Returns:
            DataLoader: test DataLoader
        """
        data_size = self._data_test_past.shape[0]
        device = self._data_test_past.device
        # This is a didactic data loader that only defines minimalistic inputs.
        # This dataloader adds some empty tensors and ones to match the expected
        # format with masks and map information.
        # NOTE(review): the placeholder shapes here are (N, 0, 1, 0) / (N, 0, 1)
        # whereas train/val use (N, 1, 0, 0) / (N, 1, 0). Kept as-is because the
        # consumers are not visible from this file -- confirm which is intended.
        test_loader = DataLoader(
            dataset=TensorDataset(
                self._data_test_past,
                torch.ones_like(self._data_test_past[..., 0]),  # Mask
                torch.zeros(data_size, 0, 1, 0, device=device),  # Map
                torch.ones(data_size, 0, 1, device=device),  # Mask map
                self._offset_test,
                self._data_test_ego_past,
                self._data_test_ego_fut,
            ),
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=self._num_workers,
        )
        return test_loader

    def _setup_datasets(
        self, data_train: torch.Tensor, data_val: torch.Tensor, data_test: torch.Tensor
    ):
        """Setup datasets: normalize and split into past future

        Populates the ``_offset_*``, ``_data_*_past``, ``_data_*_fut``,
        ``_data_*_ego_past`` and ``_data_*_ego_fut`` attributes used by the
        dataloader methods.

        Args:
            data_train: training dataset, indexed with [0] (agent trajectories)
                and [1] (ego trajectories)
            data_val: validation dataset, same layout as data_train
            data_test: test dataset, same layout as data_train
        """
        data_train, data_train_ego = data_train[0], data_train[1]
        data_val, data_val_ego = data_val[0], data_val[1]
        data_test, data_test_ego = data_test[0], data_test[1]

        # Only the agent trajectories are shifted to start at the origin; the
        # ego trajectories below are split without any normalization.
        data_train, self._offset_train = self.normalize_trajectory(data_train)
        data_val, self._offset_val = self.normalize_trajectory(data_val)
        data_test, self._offset_test = self.normalize_trajectory(data_test)

        # This is a didactic data loader that only defines minimalistic inputs.
        # NOTE(review): an earlier comment stated that an extra agent dimension
        # is added here, but nothing in this method adds one; the tensors are
        # split exactly as given, so any agent axis must already be present in
        # the inputs -- confirm against the code producing the datasets.
        self._data_train_past, self._data_train_fut = self.split_trajectory(data_train)
        self._data_val_past, self._data_val_fut = self.split_trajectory(data_val)
        self._data_test_past, self._data_test_fut = self.split_trajectory(data_test)

        self._data_train_ego_past, self._data_train_ego_fut = self.split_trajectory(
            data_train_ego
        )
        self._data_val_ego_past, self._data_val_ego_fut = self.split_trajectory(
            data_val_ego
        )
        self._data_test_ego_past, self._data_test_ego_fut = self.split_trajectory(
            data_test_ego
        )

    def split_trajectory(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split input trajectory into history and future

        Args:
            input : (batch_size, (n_agents), num_steps + num_steps_future, state_dim) tensor of
                entire trajectory [x, y]

        Returns:
            Tuple of history and future trajectories, split along the
            second-to-last (time) dimension into num_steps and num_steps_future steps

        Raises:
            AssertionError: if the time dimension is not num_steps + num_steps_future
                long or the last dimension is not state_dim
        """
        assert (
            input.shape[-2] == self._num_steps + self._num_steps_future
        ), "trajectory length ({}) does not match the expected length".format(
            input.shape[-2]
        )
        assert (
            input.shape[-1] == self._state_dim
        ), "state dimension ({}) does no match the expected dimension".format(
            input.shape[-1]
        )
        input_history, input_future = torch.split(
            input, [self._num_steps, self._num_steps_future], dim=-2
        )
        return input_history, input_future

    @staticmethod
    def normalize_trajectory(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize input trajectory by subtracting initial state

        Args:
            input : (some_shape, n_agents, num_steps + num_steps_future, state_dim) tensor of
                entire trajectory [x, y], or (some_shape, num_steps, state_dim) tensor of history x

        Returns:
            Tuple of (normalized_trajectory, offset), where
            normalized_trajectory has the same dimension as the input and offset is
            the input with its time dimension removed (the state at the first time
            step), e.g. (some_shape, state_dim)
        """
        # Clone so the returned offset does not alias the caller's tensor.
        offset = input[..., 0, :].clone()
        return input - offset.unsqueeze(-2), offset

    @staticmethod
    def unnormalize_trajectory(
        input: torch.Tensor, offset: torch.Tensor
    ) -> torch.Tensor:
        """Unnormalize trajectory by adding offset to input

        The offset's leading dimensions must match the input's leading
        dimensions; the offset is broadcast over every input dimension between
        those and the last one. Only the first ``offset.shape[-1]`` features of
        the input are shifted, remaining features are returned unchanged.

        Args:
            input : (some_shape, (n_sample), num_steps_future, state_dim) tensor of future
                trajectory y
            offset : (some_shape, 2 or 4 or 5) tensor of offset to add to y; any
                number of leading dimensions is supported (including none)

        Returns:
            Unnormalized trajectory that has the same size as input (a new tensor,
            the input is not modified)

        Raises:
            AssertionError: if the offset has more features or more dimensions than
                the input, or if their leading dimensions differ
        """
        offset_dim = offset.shape[-1]
        assert input.shape[-1] >= offset_dim
        assert input.ndim >= offset.ndim

        leading_shape = offset.shape[:-1]
        assert input.shape[: len(leading_shape)] == leading_shape

        # Insert singleton axes for every input dimension the offset lacks
        # (e.g. samples, time) so that the addition broadcasts over them.
        num_broadcast_dims = input.ndim - offset.ndim
        broadcast_offset = offset.reshape(
            [*leading_shape, *([1] * num_broadcast_dims), offset_dim]
        )

        input_clone = input.clone()
        # Out-of-place add followed by assignment (rather than +=) keeps the
        # original dtype-casting behavior when input and offset dtypes differ.
        input_clone[..., :offset_dim] = input_clone[..., :offset_dim] + broadcast_offset
        return input_clone