# jmercat — commit 5769ee4: "Removed history to avoid any unverified information being released"
import os
import pytest
import torch
from mmcv import Config
from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator
from risk_biased.mpc_planner.planner_cost import TrackingCost, TrackingCostParams
from risk_biased.mpc_planner.solver import CrossEntropySolver, CrossEntropySolverParams
from risk_biased.predictors.biased_predictor import (
LitTrajectoryPredictorParams,
LitTrajectoryPredictor,
)
from risk_biased.scene_dataset.loaders import SceneDataLoaders
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams
from risk_biased.utils.risk import get_risk_estimator
from risk_biased.utils.planner_utils import to_state
@pytest.fixture(scope="module")
def params():
    """Build a small test configuration shared by every test in this module.

    Loads ``learning_config.py`` and merges ``planning_config.py`` into it
    with ``Config.update``. It then overrides solver and model sizes with
    small values so the tests run quickly.

    The fixture is module-scoped: it runs once per module, so
    ``torch.manual_seed(0)`` is also applied only once per module.

    Returns:
        The merged and overridden ``mmcv`` ``Config`` object.
    """
    torch.manual_seed(0)
    working_dir = os.path.dirname(os.path.realpath(__file__))
    # Config files live three directories up, under risk_biased/config.
    config_dir = os.path.join(working_dir, "..", "..", "..", "risk_biased", "config")
    paths = [
        os.path.join(config_dir, "learning_config.py"),
        os.path.join(config_dir, "planning_config.py"),
    ]

    # Start from the first config and merge each following file into it, in order.
    cfg = Config.fromfile(paths[0])
    for path in paths[1:]:
        cfg.update(Config.fromfile(path))

    # Cross-entropy solver settings, kept tiny for speed.
    cfg.num_control_samples = 10
    cfg.num_elite = 3
    cfg.iter_max = 3
    cfg.smoothing_factor = 0.2
    cfg.mean_warm_start = True
    # Trajectory / state dimensions.
    cfg.num_steps = 3
    cfg.num_steps_future = 5
    cfg.state_dim = 5
    cfg.dynamic_state_dim = 5
    cfg.map_state_dim = 2
    cfg.max_size_lane = 2
    # Predictor network size.
    cfg.latent_dim = 2
    cfg.hidden_dim = 64
    cfg.num_hidden_layers = 3
    return cfg
class TestCrossEntropySolver:
    """Unit tests for ``CrossEntropySolver`` built from the ``params`` fixture."""

    @pytest.fixture(autouse=True)
    def setup(self, params):
        """Create the solver, its cost/dynamics components and a predictor.

        This fixture is autouse and function-scoped, so it runs again before
        every test in the class. Each test therefore gets a freshly
        constructed ``self.solver_default``.
        """
        self.solver_params = CrossEntropySolverParams.from_config(params)
        self.dynamics_model = PositionVelocityDoubleIntegrator(params.dt)
        self.interaction_cost_function = TTCCostTorch(TTCCostParams.from_config(params))
        self.tracking_cost_function = TrackingCost(
            TrackingCostParams.from_config(params)
        )
        self.risk_estimator = get_risk_estimator(params.risk_estimator)
        # Initial control distribution: normal-random mean and uniform [0, 1) std,
        # shaped (1, num_steps_future, control_dim).
        self.control_input_mean_default = torch.randn(
            1, params.num_steps_future, self.dynamics_model.control_dim
        )
        self.control_input_std_default = torch.rand_like(
            self.control_input_mean_default
        )
        self.solver_default = CrossEntropySolver(
            self.solver_params,
            self.dynamics_model,
            self.control_input_mean_default,
            self.control_input_std_default,
            self.tracking_cost_function,
            self.interaction_cost_function,
            self.risk_estimator,
        )
        predictor_params = LitTrajectoryPredictorParams.from_config(params)
        self.predictor = LitTrajectoryPredictor(
            predictor_params,
            TTCCostParams.from_config(params),
            SceneDataLoaders.unnormalize_trajectory,
        )
        self.normalizer = SceneDataLoaders.normalize_trajectory

    def test_reset(self):
        """``reset()`` restores the iteration counter, the initial control
        distribution and the cached prediction samples."""
        self.solver_default.reset()
        assert self.solver_default._iter_current == 0
        assert torch.allclose(
            self.solver_default._control_input_mean, self.control_input_mean_default
        )
        assert torch.allclose(
            self.solver_default._control_input_std, self.control_input_std_default
        )
        # Identity check: `== None` is non-idiomatic (PEP 8) and misbehaves on tensors.
        assert self.solver_default._latest_ado_position_future_samples is None

    def test_get_elites(self, params):
        """``_get_elites`` keeps the ``num_elite`` lowest-risk samples, sorted by risk."""
        control_input = torch.randn(
            params.num_control_samples,
            1,
            params.num_steps_future,
            self.dynamics_model.control_dim,
        )
        # The lowest risks are at indices 0, 2, 4 (0.0, 0.1, 0.2); the expected
        # values below rely on params.num_elite == 3.
        risk = torch.Tensor(
            [0.0, 1.0, 0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6]
        ).unsqueeze(-1)
        elite_control_input, elite_risk = self.solver_default._get_elites(
            control_input, risk
        )
        assert elite_control_input.shape == torch.Size(
            [
                params.num_elite,
                1,
                params.num_steps_future,
                self.dynamics_model.control_dim,
            ]
        )
        assert elite_risk.shape == torch.Size([params.num_elite, 1])
        assert torch.allclose(elite_control_input, control_input[[0, 2, 4]])
        assert torch.allclose(elite_risk, torch.Tensor([0.0, 0.1, 0.2]).unsqueeze(-1))

    @pytest.mark.parametrize(
        "num_elite, smoothing_factor", ([3, 0.0], [3, 1.0], [1, 0.0], [1, 1.0])
    )
    def test_update_control_distribution(self, params, num_elite, smoothing_factor):
        """Check how ``_update_control_distribution`` uses the smoothing factor.

        With ``smoothing_factor == 0`` the distribution is replaced by the
        statistics of the (constant, all-ones) elites: mean 1, std 0. With
        ``smoothing_factor == 1`` the distribution stays at its initial values.
        """
        solver = CrossEntropySolver(
            self.solver_params,
            self.dynamics_model,
            self.control_input_mean_default,
            self.control_input_std_default,
            self.tracking_cost_function,
            self.interaction_cost_function,
            self.risk_estimator,
        )
        solver.params.num_elite = num_elite
        solver.params.smoothing_factor = smoothing_factor
        elite_control_input = torch.ones(
            num_elite, params.num_steps_future, self.dynamics_model.control_dim
        )
        solver._update_control_distribution(elite_control_input)
        if smoothing_factor == 0.0:
            assert torch.allclose(
                solver._control_input_mean, torch.ones_like(solver._control_input_mean)
            )
            assert torch.allclose(
                solver._control_input_std, torch.zeros_like(solver._control_input_std)
            )
        else:
            assert torch.allclose(
                solver._control_input_mean, solver.control_input_mean_init
            )
            assert torch.allclose(
                solver._control_input_std, solver.control_input_std_init
            )

    @pytest.mark.parametrize(
        "risk_level, num_prediction_samples",
        [(0.0, 1), (0.0, 10), (0.5, 1), (0.5, 10)],
    )
    def test_sample_prediction(self, params, risk_level, num_prediction_samples):
        """``sample_prediction`` returns one future per sample, agent and future step."""
        num_agents = 1
        ado_state_history = to_state(
            torch.randn(num_agents, params.num_steps, params.state_dim), params.dt
        )
        ego_state_history = to_state(
            torch.randn(1, params.num_steps, params.state_dim), params.dt
        )
        ego_state_future = to_state(
            torch.randn(1, params.num_steps_future, params.state_dim), params.dt
        )
        # The returned weights are not checked by this test.
        ado_position_future_samples, _ = CrossEntropySolver.sample_prediction(
            self.predictor,
            ado_state_history,
            self.normalizer,
            ego_state_history,
            ego_state_future,
            num_prediction_samples=num_prediction_samples,
            risk_level=risk_level,
        )
        assert ado_position_future_samples.shape == torch.Size(
            [num_prediction_samples, num_agents, params.num_steps_future]
        )

    @pytest.mark.parametrize(
        "mean_warm_start, risk_level, resample_prediction, risk_in_predictor",
        [
            (False, 0.0, False, False),
            (False, 0.0, False, True),
            (False, 0.0, True, False),
            (False, 0.0, True, True),
            (False, 0.5, False, False),
            (False, 0.5, False, True),
            (False, 0.5, True, False),
            (False, 0.5, True, True),
            (True, 0.0, False, False),
            (True, 0.0, False, True),
            (True, 0.0, True, False),
            (True, 0.0, True, True),
            (True, 0.5, False, False),
            (True, 0.5, False, True),
            (True, 0.5, True, False),
            (True, 0.5, True, True),
        ],
    )
    def test_solve(
        self,
        params,
        mean_warm_start,
        risk_level,
        resample_prediction,
        risk_in_predictor,
    ):
        """Check the solver state after ``solve()`` runs to completion.

        After ``solve()``:
          * the solver has run exactly ``iter_max`` iterations;
          * the latest prediction has shape
            ``(num_prediction_samples, num_agents, num_steps_future)``;
          * without mean warm start, the initial control distribution is unchanged;
          * with mean warm start, the initial mean is the final mean shifted by
            one step, the last step keeps the default, and the std is unchanged.
        """
        num_prediction_samples = 5
        num_agents = 1
        ego_state_history = to_state(
            torch.randn(num_agents, params.num_steps, params.state_dim), params.dt
        )
        ego_state_target_trajectory = to_state(
            torch.randn(num_agents, params.num_steps_future, params.state_dim),
            params.dt,
        )
        # NOTE(review): ado history is built from 2-D inputs here, but from
        # params.state_dim in test_sample_prediction. Confirm that to_state is
        # meant to handle both forms.
        ado_state_history = to_state(
            torch.randn(num_agents, params.num_steps, 2), params.dt
        )
        self.solver_default.params.mean_warm_start = mean_warm_start
        self.solver_default.solve(
            self.predictor,
            ego_state_history,
            ego_state_target_trajectory,
            ado_state_history,
            self.normalizer,
            num_prediction_samples=num_prediction_samples,
            risk_level=risk_level,
            resample_prediction=resample_prediction,
            risk_in_predictor=risk_in_predictor,
        )
        assert self.solver_default._iter_current == params.iter_max
        assert self.solver_default.fetch_latest_prediction().shape == torch.Size(
            [num_prediction_samples, num_agents, params.num_steps_future]
        )
        if not mean_warm_start:
            assert torch.allclose(
                self.solver_default.control_input_mean_init,
                self.control_input_mean_default,
            )
            assert torch.allclose(
                self.solver_default.control_input_std_init,
                self.control_input_std_default,
            )
        else:
            # Warm start: the final time step falls back to the default mean...
            assert torch.allclose(
                self.solver_default.control_input_mean_init[:, -1],
                self.control_input_mean_default[:, -1],
            )
            # ...and the remaining steps are the solved mean shifted one step earlier.
            assert torch.allclose(
                self.solver_default.control_input_mean_init[:, :-1],
                self.solver_default._control_input_mean[:, 1:],
            )
            assert torch.allclose(
                self.solver_default.control_input_std_init,
                self.control_input_std_default,
            )