# Provenance note: authored by jmercat (commit 5769ee4).
# Revision history was removed to avoid releasing unverified information.
from typing import Tuple
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
class SceneDataLoaders:
    """Load a scene dataset and pre-process it (normalization, un-normalization).

    Each of ``data_train``/``data_val``/``data_test`` is indexable as a pair:
    element 0 is the agent trajectory tensor and element 1 is the ego trajectory
    tensor, both shaped ``(batch, num_steps + num_steps_future, state_dim)``.

    Args:
        state_dim : dimension of the observed state (2 for x,y position observation)
        num_steps : number of observed steps
        num_steps_future : number of steps in the future
        batch_size: set data loader with this batch size
        data_train: training dataset
        data_val: validation dataset
        data_test: test dataset
        num_workers: number of workers to use for data loading
    """

    def __init__(
        self,
        state_dim: int,
        num_steps: int,
        num_steps_future: int,
        batch_size: int,
        data_train: torch.Tensor,
        data_val: torch.Tensor,
        data_test: torch.Tensor,
        num_workers: int = 0,
    ):
        self._batch_size = batch_size
        self._num_workers = num_workers
        self._state_dim = state_dim
        self._num_steps = num_steps
        self._num_steps_future = num_steps_future
        # Normalizes and splits the three datasets into past/future parts.
        self._setup_datasets(data_train, data_val, data_test)

    def _map_placeholders(
        self, data_size: int, device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build empty (zero-element) map and map-mask placeholder tensors.

        This didactic loader has no map data; these placeholders only exist so
        the batch tuple matches the format expected by the downstream model.
        Shapes follow the train/val convention: one agent dim, zero map elements.

        Args:
            data_size: number of samples (first dimension of every dataset tensor)
            device: device on which to allocate the placeholders

        Returns:
            Tuple of (map placeholder, map-mask placeholder)
        """
        # Both tensors contain zero elements, so their values are irrelevant;
        # zeros/ones are used (instead of torch.empty) purely for consistency.
        map_placeholder = torch.zeros(data_size, 1, 0, 0, device=device)
        map_mask_placeholder = torch.ones(data_size, 1, 0, device=device)
        return map_placeholder, map_mask_placeholder

    def train_dataloader(self, shuffle=True, drop_last=True) -> DataLoader:
        """Setup and return training DataLoader

        Returns:
            DataLoader: training DataLoader
        """
        data_size = self._data_train_past.shape[0]
        # This is a didactic data loader that only defines minimalistic inputs.
        # This dataloader adds some empty tensors and ones to match the expected
        # format with masks and map information.
        map_placeholder, map_mask = self._map_placeholders(
            data_size, self._data_train_past.device
        )
        train_loader = DataLoader(
            dataset=TensorDataset(
                self._data_train_past,
                torch.ones_like(self._data_train_past[..., 0]),  # Mask past
                self._data_train_fut,
                torch.ones_like(self._data_train_fut[..., 0]),  # Mask fut
                torch.ones_like(self._data_train_fut[..., 0]),  # Mask loss
                map_placeholder,  # Map
                map_mask,  # Mask map
                self._offset_train,
                self._data_train_ego_past,
                self._data_train_ego_fut,
            ),
            batch_size=self._batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=self._num_workers,
        )
        return train_loader

    def val_dataloader(self, shuffle=False, drop_last=False) -> DataLoader:
        """Setup and return validation DataLoader

        Returns:
            DataLoader: validation DataLoader
        """
        data_size = self._data_val_past.shape[0]
        # This is a didactic data loader that only defines minimalistic inputs.
        # This dataloader adds some empty tensors and ones to match the expected
        # format with masks and map information.
        map_placeholder, map_mask = self._map_placeholders(
            data_size, self._data_val_past.device
        )
        val_loader = DataLoader(
            dataset=TensorDataset(
                self._data_val_past,
                torch.ones_like(self._data_val_past[..., 0]),  # Mask past
                self._data_val_fut,
                torch.ones_like(self._data_val_fut[..., 0]),  # Mask fut
                torch.ones_like(self._data_val_fut[..., 0]),  # Mask loss
                map_placeholder,  # Map
                map_mask,  # Mask map
                self._offset_val,
                self._data_val_ego_past,
                self._data_val_ego_fut,
            ),
            batch_size=self._batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=self._num_workers,
        )
        return val_loader

    def test_dataloader(self) -> DataLoader:
        """Setup and return test DataLoader

        Note: unlike train/val, the test batch tuple intentionally omits the
        future trajectory and its masks (past, mask, map, mask map, offset,
        ego past, ego fut).

        Returns:
            DataLoader: test DataLoader
        """
        data_size = self._data_test_past.shape[0]
        # This is a didactic data loader that only defines minimalistic inputs.
        # This dataloader adds some empty tensors and ones to match the expected
        # format with masks and map information.
        # Fix: placeholders previously used transposed shapes
        # (data_size, 0, 1, 0) / (data_size, 0, 1); now consistent with
        # train/val: (data_size, 1, 0, 0) / (data_size, 1, 0).
        map_placeholder, map_mask = self._map_placeholders(
            data_size, self._data_test_past.device
        )
        test_loader = DataLoader(
            dataset=TensorDataset(
                self._data_test_past,
                torch.ones_like(self._data_test_past[..., 0]),  # Mask
                map_placeholder,  # Map
                map_mask,  # Mask map
                self._offset_test,
                self._data_test_ego_past,
                self._data_test_ego_fut,
            ),
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=self._num_workers,
        )
        return test_loader

    def _setup_datasets(
        self, data_train: torch.Tensor, data_val: torch.Tensor, data_test: torch.Tensor
    ):
        """Setup datasets: normalize and split into past future

        Args:
            data_train: training dataset (pair of agent and ego trajectories)
            data_val: validation dataset (pair of agent and ego trajectories)
            data_test: test dataset (pair of agent and ego trajectories)
        """
        # Each dataset is a pair: index 0 is the agent data, index 1 the ego data.
        data_train, data_train_ego = data_train[0], data_train[1]
        data_val, data_val_ego = data_val[0], data_val[1]
        data_test, data_test_ego = data_test[0], data_test[1]
        # Normalize agent trajectories; offsets are kept so that predictions
        # can be mapped back to absolute coordinates later.
        data_train, self._offset_train = self.normalize_trajectory(data_train)
        data_val, self._offset_val = self.normalize_trajectory(data_val)
        data_test, self._offset_test = self.normalize_trajectory(data_test)
        # This is a didactic data loader that only defines minimalistic inputs.
        # An extra dimension is added to account for the number of agents in the scene.
        # In this minimal input there is only one but the model using the data
        # expects any number of agents.
        self._data_train_past, self._data_train_fut = self.split_trajectory(data_train)
        self._data_val_past, self._data_val_fut = self.split_trajectory(data_val)
        self._data_test_past, self._data_test_fut = self.split_trajectory(data_test)
        # NOTE: ego trajectories are split but NOT normalized — presumably they
        # are consumed in absolute coordinates; confirm against the model code.
        self._data_train_ego_past, self._data_train_ego_fut = self.split_trajectory(
            data_train_ego
        )
        self._data_val_ego_past, self._data_val_ego_fut = self.split_trajectory(
            data_val_ego
        )
        self._data_test_ego_past, self._data_test_ego_fut = self.split_trajectory(
            data_test_ego
        )

    def split_trajectory(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split input trajectory into history and future

        Args:
            input : (batch_size, (n_agents), num_steps + num_steps_future, state_dim)
                tensor of entire trajectory [x, y]

        Returns:
            Tuple of history and future trajectories

        Raises:
            AssertionError: if the trajectory length or state dimension does not
                match the values given at construction time.
        """
        assert (
            input.shape[-2] == self._num_steps + self._num_steps_future
        ), "trajectory length ({}) does not match the expected length".format(
            input.shape[-2]
        )
        assert (
            input.shape[-1] == self._state_dim
        ), "state dimension ({}) does not match the expected dimension".format(
            input.shape[-1]
        )
        # Split along the time dimension (dim=-2) into past / future chunks.
        input_history, input_future = torch.split(
            input, [self._num_steps, self._num_steps_future], dim=-2
        )
        return input_history, input_future

    @staticmethod
    def normalize_trajectory(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Normalize input trajectory by subtracting initial state

        Args:
            input : (some_shape, n_agents, num_steps + num_steps_future, state_dim)
                tensor of entire trajectory [x, y], or
                (some_shape, num_steps, state_dim) tensor of history x

        Returns:
            Tuple of (normalized_trajectory, offset), where
            normalized_trajectory has the same dimension as the input and offset is a
            (some_shape, state_dim) tensor corresponding to the initial state
        """
        # Clone so the returned offset does not alias (and later mutate with)
        # the input tensor.
        offset = input[..., 0, :].clone()
        return input - offset.unsqueeze(-2), offset

    @staticmethod
    def unnormalize_trajectory(
        input: torch.Tensor, offset: torch.Tensor
    ) -> torch.Tensor:
        """Unnormalize trajectory by adding offset to input

        Args:
            input : (some_shape, (n_sample), num_steps_future, state_dim) tensor of
                future trajectory y
            offset : (some_shape, 2 or 4 or 5) tensor of offset to add to y

        Returns:
            Unnormalized trajectory that has the same size as input
        """
        offset_dim = offset.shape[-1]
        # Only the first offset_dim state channels are shifted; any extra
        # channels of the input (e.g. velocity) are left untouched.
        assert input.shape[-1] >= offset_dim
        input_clone = input.clone()
        if offset.ndim == 2:
            batch_size, _ = offset.shape
            assert input_clone.shape[0] == batch_size
            # Reshape offset with singleton middle dims so it broadcasts over
            # any (n_sample)/time dimensions of the input.
            input_clone[..., :offset_dim] = input_clone[
                ..., :offset_dim
            ] + offset.reshape(
                [batch_size, *([1] * (input_clone.ndim - 2)), offset_dim]
            )
        elif offset.ndim == 3:
            batch_size, num_agents, _ = offset.shape
            assert input_clone.shape[0] == batch_size
            assert input_clone.shape[1] == num_agents
            input_clone[..., :offset_dim] = input_clone[
                ..., :offset_dim
            ] + offset.reshape(
                [batch_size, num_agents, *([1] * (input_clone.ndim - 3)), offset_dim]
            )
        # NOTE(review): for any other offset.ndim the input is returned as a
        # plain clone with NO offset applied — kept as-is since callers may
        # rely on it, but this silent no-op looks suspicious; verify upstream.
        return input_clone