import os

import pytest
import torch
import numpy as np
from mmcv import Config

from risk_biased.utils.cost import BaseCostTorch, TTCCostTorch, DistanceCostTorch
from risk_biased.utils.cost import BaseCostNumpy, TTCCostNumpy, DistanceCostNumpy
from risk_biased.utils.cost import (
    CostParams,
    TTCCostParams,
    DistanceCostParams,
)


@pytest.fixture
def params():
    """Load the learning config and override the cost-related fields used by the tests."""
    torch.manual_seed(0)
    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(
        working_dir, "..", "..", "..", "risk_biased", "config", "learning_config.py"
    )
    cfg = Config.fromfile(config_path)
    cfg.cost_scale = 1
    cfg.cost_reduce = "mean"
    cfg.ego_length = 4
    cfg.ego_width = 1.75
    cfg.distance_bandwidth = 2
    cfg.time_bandwidth = 2
    cfg.min_velocity_diff = 0.01
    return cfg


def get_fake_input(batch_size, num_steps, is_torch, use_mask, num_agents=0):
    """Build random positions, velocities, and an optional validity mask.

    Returns torch tensors when is_torch is True, numpy arrays otherwise.
    Shapes are (batch, [agents,] steps, 2); the mask drops the last dimension.
    """
    if num_agents <= 0:
        shape = [batch_size, num_steps, 2]
    else:
        shape = [batch_size, num_agents, num_steps, 2]
    if is_torch:
        x1 = torch.rand(shape)
        x2 = torch.rand(shape)
        v1 = torch.rand(shape)
        v2 = torch.rand(shape)
        if use_mask:
            # Keep roughly 90% of the entries valid.
            mask = torch.rand(shape[:-1]) > 0.1
        else:
            mask = None
    else:
        x1 = np.random.uniform(size=shape)
        x2 = np.random.uniform(size=shape)
        v1 = np.random.uniform(size=shape)
        v2 = np.random.uniform(size=shape)
        if use_mask:
            mask = np.random.uniform(size=shape[:-1]) > 0.1
        else:
            mask = None
    return x1, x2, v1, v2, mask
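

# NOTE: the parametrize values below are an illustrative reconstruction; the
# original value grid is not preserved in this file. Any combination of a
# reduction mode, batch/step sizes, backend flag, mask flag, and agent count
# exercises the same assertions.
@pytest.mark.parametrize(
    "reduce, batch_size, num_steps, is_torch, use_mask, num_agents",
    [
        ("mean", 4, 10, True, True, 0),
        ("min", 4, 10, True, False, 3),
        ("max", 4, 10, False, True, 0),
        ("mean", 4, 10, False, False, 3),
    ],
)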
def test_base_cost(
    params,
    reduce: str,
    batch_size: int,
    num_steps: int,
    is_torch: bool,
    use_mask: bool,
    num_agents: int,
):
    params.cost_reduce = reduce
    cost_params = CostParams.from_config(params)
    if is_torch:
        base_cost = BaseCostTorch(cost_params)
    else:
        base_cost = BaseCostNumpy(cost_params)
    x1, x2, v1, v2, mask = get_fake_input(
        batch_size, num_steps, is_torch, use_mask, num_agents
    )
    cost, _ = base_cost(x1, x2, v1, v2, mask)
    # The time dimension is reduced away.
    if num_agents > 0:
        assert cost.shape == (batch_size, num_agents)
    else:
        assert cost.shape == (batch_size,)
    # The base cost is identically zero; only its scale and default
    # bandwidths are exposed.
    assert (cost == 0).all()
    assert base_cost.scale == params.cost_scale
    assert base_cost.distance_bandwidth == 1
    assert base_cost.time_bandwidth == 1
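

# NOTE: as above, these parameter tuples are assumptions for illustration.
# The (param_class, cost_class) pairs and the is_torch flag must stay
# consistent: the *Torch costs take tensors, the *Numpy costs take arrays.
@pytest.mark.parametrize(
    "param_class, cost_class, reduce, batch_size, num_steps, is_torch, use_mask, num_agents",
    [
        (TTCCostParams, TTCCostTorch, "max", 4, 10, True, True, 0),
        (DistanceCostParams, DistanceCostTorch, "mean", 4, 10, True, False, 3),
        (TTCCostParams, TTCCostNumpy, "max", 4, 10, False, True, 3),
        (DistanceCostParams, DistanceCostNumpy, "mean", 4, 10, False, False, 0),
    ],
)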
def test_generic_cost(
    params,
    param_class,
    cost_class,
    reduce: str,
    batch_size: int,
    num_steps: int,
    is_torch: bool,
    use_mask: bool,
    num_agents: int,
):
    params.cost_reduce = reduce
    cost_params = param_class.from_config(params)
    x1, x2, v1, v2, mask = get_fake_input(
        batch_size, num_steps, is_torch, use_mask, num_agents
    )
    compute_cost = cost_class(cost_params)
    cost, _ = compute_cost(x1, x2, v1, v2, mask)
    # The time dimension is reduced away.
    if num_agents > 0:
        assert cost.shape == (batch_size, num_agents)
    else:
        assert cost.shape == (batch_size,)
    assert (cost != 0).any()
    assert compute_cost.scale == params.cost_scale
    # Rescale the cost for comparison.
    compute_cost.scale = params.cost_scale + 10
    assert compute_cost.scale != params.cost_scale
    rescaled_cost, _ = compute_cost(x1, x2, v1, v2, mask)
    # All rescaled costs are larger, except where the cost is 0 and
    # rescaling leaves it unchanged.
    assert (rescaled_cost >= cost).all()
    # At least some rescaled costs are strictly larger than at the normal scale.
    assert (rescaled_cost > cost).any()
    # Compute mean- and min-reduced costs to compare against.
    params.cost_reduce = "mean"
    cost_params_mean = param_class.from_config(params)
    cost_function_mean = cost_class(cost_params_mean)
    cost_mean, _ = cost_function_mean(x1, x2, v1, v2)
    params.cost_reduce = "min"
    cost_params_min = param_class.from_config(params)
    cost_function_min = cost_class(cost_params_min)
    cost_min, _ = cost_function_min(x1, x2, v1, v2)
    # The max reduction is larger than the mean.
    if reduce == "max":
        assert (cost >= cost_mean).all()
    # The min reduction is lower than any other.
    assert (cost_mean >= cost_min).all()
    assert (cost >= cost_min).all()
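

# With the fixture and the (illustrative) parametrize decorators in place,
# this module should be collectable as a standard pytest file, assuming the
# risk_biased package and mmcv are installed, e.g.:
#   python -m pytest <path to this file>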