# risk_biased_prediction/tests/risk_biased/models/test_interaction_encoder.py
# Uploaded by jmercat - commit 5769ee4:
#   "Removed history to avoid any unverified information being released"
# (raw | history | blame - 8.25 kB)
import os
import pytest
import torch
import torch.nn as nn
from mmcv import Config
from risk_biased.models.cvae_encoders import (
CVAEEncoder,
BiasedEncoderNN,
FutureEncoderNN,
InferenceEncoderNN,
)
from risk_biased.models.latent_distributions import (
GaussianLatentDistribution,
QuantizedDistributionCreator,
)
from risk_biased.models.cvae_params import CVAEParams
@pytest.fixture(scope="module")
def params():
    """Build a small, CPU-only configuration shared by the encoder tests.

    ``learning_config.py`` is loaded first and ``waymo_config.py`` is merged into
    it with ``Config.update``; the dimensions are then shrunk so the tests run
    quickly.

    Note: the fixture is module-scoped, so every test in this file receives the
    same config object, and tests that assign attributes on it (e.g.
    ``sequence_encoder_type``) affect later tests.
    """
    torch.manual_seed(0)
    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_dir = os.path.join(working_dir, "..", "..", "..", "risk_biased", "config")
    config_paths = [
        os.path.join(config_dir, "learning_config.py"),
        os.path.join(config_dir, "waymo_config.py"),
    ]
    # The first file is the base; each following file is merged into it in order.
    cfg = Config.fromfile(config_paths[0])
    for path in config_paths[1:]:
        cfg.update(Config.fromfile(path))

    # Small sizes keep the forward passes cheap.
    cfg.batch_size = 4
    cfg.state_dim = 5
    cfg.dynamic_state_dim = 5
    cfg.map_state_dim = 2
    cfg.num_steps = 3
    cfg.num_steps_future = 4
    cfg.latent_dim = 2
    cfg.hidden_dim = 64
    cfg.device = "cpu"
    cfg.sequence_encoder_type = "LSTM"
    cfg.sequence_decoder_type = "MLP"
    return cfg
@pytest.mark.parametrize(
    "num_agents, num_map_objects, encoder_type, interaction_nn_class",
    [
        (4, 5, "MLP", BiasedEncoderNN),
        (2, 4, "LSTM", BiasedEncoderNN),
        (3, 2, "maskedLSTM", BiasedEncoderNN),
        (4, 5, "MLP", FutureEncoderNN),
        (2, 4, "LSTM", FutureEncoderNN),
        (3, 2, "maskedLSTM", FutureEncoderNN),
        (4, 5, "MLP", InferenceEncoderNN),
        (2, 4, "LSTM", InferenceEncoderNN),
        (3, 2, "maskedLSTM", InferenceEncoderNN),
    ],
)
def test_attention_encoder_nn(
    params,
    num_agents: int,
    num_map_objects: int,
    encoder_type: str,
    interaction_nn_class: "type[nn.Module]",
):
    """Check that each interaction encoder produces a ``2 * latent_dim`` output per agent.

    ``FutureEncoderNN`` is built with ``num_steps + num_steps_future`` steps and is
    given a future trajectory ``y``. ``BiasedEncoderNN`` is given a per-agent
    ``risk_level``. The other inputs are shared by all encoders.
    """
    params.sequence_encoder_type = encoder_type
    cvae_params = CVAEParams.from_config(params)
    is_future_encoder = interaction_nn_class is FutureEncoderNN

    # The future encoder's sequence length covers both past and future steps.
    num_steps = cvae_params.num_steps
    if is_future_encoder:
        num_steps += cvae_params.num_steps_future
    model = interaction_nn_class(
        cvae_params,
        num_steps=num_steps,
        latent_dim=2 * cvae_params.latent_dim,
    )
    assert model.latent_dim == 2 * params.latent_dim
    assert model.hidden_dim == params.hidden_dim

    # Past trajectories, re-expressed relative to each agent's last time step.
    x = torch.rand(params.batch_size, num_agents, params.num_steps, params.state_dim)
    offset = x[:, :, -1, :]
    x = x - offset.unsqueeze(-2)
    x_ego = x[:, 0:1]
    mask_x = torch.rand(params.batch_size, num_agents, params.num_steps) > 0.1
    encoded_absolute = torch.rand(params.batch_size, num_agents, params.hidden_dim)
    encoded_map = torch.rand(params.batch_size, num_map_objects, params.hidden_dim)
    mask_map = torch.rand(params.batch_size, num_map_objects) > 0.1

    y = y_ego = mask_y = None
    if is_future_encoder:
        y = torch.rand(
            params.batch_size, num_agents, params.num_steps_future, params.state_dim
        )
        y = y - offset.unsqueeze(-2)
        y_ego = y[:, 0:1]
        # Boolean mask, built the same way as mask_x and mask_map.
        mask_y = (
            torch.rand(params.batch_size, num_agents, params.num_steps_future) > 0.1
        )

    risk_level = (
        torch.rand(params.batch_size, num_agents)
        if interaction_nn_class is BiasedEncoderNN
        else None
    )

    output = model(
        x,
        mask_x,
        encoded_absolute,
        encoded_map,
        mask_map,
        y=y,
        mask_y=mask_y,
        x_ego=x_ego,
        y_ego=y_ego,
        offset=offset,
        risk_level=risk_level,
    )
    # check shape
    assert output.shape == (params.batch_size, num_agents, 2 * params.latent_dim)
@pytest.mark.parametrize(
    "num_agents, num_map_objects, encoder_type, interaction_nn_class, latent_distribution_class",
    [
        (2, 8, "MLP", BiasedEncoderNN, GaussianLatentDistribution),
        (7, 5, "LSTM", BiasedEncoderNN, GaussianLatentDistribution),
        (2, 10, "maskedLSTM", BiasedEncoderNN, QuantizedDistributionCreator),
        (2, 8, "MLP", FutureEncoderNN, GaussianLatentDistribution),
        (7, 5, "LSTM", FutureEncoderNN, QuantizedDistributionCreator),
        (2, 10, "maskedLSTM", FutureEncoderNN, GaussianLatentDistribution),
        (2, 8, "MLP", InferenceEncoderNN, QuantizedDistributionCreator),
        (7, 5, "LSTM", InferenceEncoderNN, GaussianLatentDistribution),
        (2, 10, "maskedLSTM", InferenceEncoderNN, GaussianLatentDistribution),
    ],
)
def test_attention_cvae_encoder(
    params,
    num_agents: int,
    num_map_objects: int,
    encoder_type: str,
    interaction_nn_class,
    latent_distribution_class,
):
    """Check the latent distribution returned by ``CVAEEncoder`` and its samples.

    Verifies that ``mu`` and ``logvar`` have shape
    ``(batch_size, num_agents, latent_dim)``, that sampling with the default
    arguments and with ``n_samples=2`` yields the expected shapes, and that two
    default draws differ.

    NOTE(review): ``latent_distribution_class`` is currently unused. The encoder
    is always built with ``GaussianLatentDistribution``, so the
    ``QuantizedDistributionCreator`` rows only exercise the Gaussian path.
    """
    params.sequence_encoder_type = encoder_type
    is_future_encoder = interaction_nn_class is FutureEncoderNN

    if is_future_encoder:
        risk_level = None
        y = torch.rand(
            params.batch_size, num_agents, params.num_steps_future, params.state_dim
        )
        # Boolean mask, built the same way as mask_x and mask_map below.
        mask_y = (
            torch.rand(params.batch_size, num_agents, params.num_steps_future) > 0.1
        )
    else:
        risk_level = torch.rand(params.batch_size, num_agents)
        y = None
        mask_y = None

    # The future encoder's sequence length covers both past and future steps.
    num_steps = params.num_steps
    if is_future_encoder:
        num_steps += params.num_steps_future
    model = interaction_nn_class(
        CVAEParams.from_config(params),
        num_steps=num_steps,
        latent_dim=2 * params.latent_dim,
    )
    # TODO: Add test for QuantizedDistributionCreator (use latent_distribution_class).
    encoder = CVAEEncoder(model, GaussianLatentDistribution)
    # check latent_dim
    assert encoder.latent_dim == 2 * params.latent_dim

    # Past trajectories, re-expressed relative to each agent's last time step.
    x = torch.rand(params.batch_size, num_agents, params.num_steps, params.state_dim)
    offset = x[:, :, -1, :]
    x = x - offset.unsqueeze(-2)
    x_ego = x[:, 0:1]
    y_ego = None
    if y is not None:
        y = y - offset.unsqueeze(-2)
        y_ego = y[:, 0:1]
    mask_x = torch.rand(params.batch_size, num_agents, params.num_steps) > 0.1
    encoded_absolute = torch.rand(params.batch_size, num_agents, params.hidden_dim)
    encoded_map = torch.rand(params.batch_size, num_map_objects, params.hidden_dim)
    mask_map = torch.rand(params.batch_size, num_map_objects) > 0.1

    latent_distribution = encoder(
        x=x,
        mask_x=mask_x,
        encoded_absolute=encoded_absolute,
        encoded_map=encoded_map,
        mask_map=mask_map,
        y=y,
        mask_y=mask_y,
        x_ego=x_ego,
        y_ego=y_ego,
        offset=offset,
        risk_level=risk_level,
    )
    latent_mean = latent_distribution.mu
    latent_logvar = latent_distribution.logvar
    # check shape: the 2 * latent_dim encoder output yields latent_dim-sized mu and logvar
    expected_shape = (params.batch_size, num_agents, params.latent_dim)
    assert latent_mean.shape == latent_logvar.shape == expected_shape

    # check shape with the default sampling arguments
    latent_sample_1, weights = latent_distribution.sample()
    assert latent_sample_1.shape == latent_mean.shape
    assert latent_sample_1.shape[:-1] == weights.shape

    # check shape when n_samples = 2: a sample axis is inserted before latent_dim
    latent_sample_2, weights = latent_distribution.sample(n_samples=2)
    assert latent_sample_2.shape == (
        params.batch_size,
        num_agents,
        2,
        params.latent_dim,
    )

    # make sure sampling is non-deterministic
    latent_sample_3, weights = latent_distribution.sample()
    assert not torch.allclose(latent_sample_1, latent_sample_3)