from datasets import load_dataset, Dataset
import fire
from functools import partial, update_wrapper
import numpy
import os
from typing import Dict, Iterable, Tuple
import sys
import time
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from mmcv import Config
import plotly.graph_objects as go
from torch.utils.data.dataloader import DataLoader

from risk_biased.utils.load_model import get_predictor
from risk_biased.utils.torch_utils import load_weights
from risk_biased.utils.waymo_dataloader import WaymoDataloaders
from risk_biased.predictors.biased_predictor import (
    LitTrajectoryPredictor,
)


def to_numpy(**kwargs):
    """Detach every keyword tensor, move it to the CPU and convert it to numpy.

    Returns:
        A dict mapping each keyword name to the corresponding numpy array.
    """
    return {name: tensor.detach().cpu().numpy() for name, tensor in kwargs.items()}


def get_scatter_data(x, mask_x, name, **kwargs):
    """Build one ``go.Scatter`` trace per row of ``x``.

    Args:
        x: array indexed as ``x[k, t, 0]`` / ``x[k, t, 1]`` (x / y coordinate of
            row ``k`` at step ``t``).
        mask_x: boolean array; ``mask_x[k]`` selects the valid steps of row ``k``.
        name: legend name. Only the first trace is shown in the legend so the
            group appears once.
        **kwargs: forwarded unchanged to every ``go.Scatter``.

    Returns:
        A list with ``x.shape[0]`` scatter traces.
    """
    return [
        go.Scatter(
            x=x[k, mask_x[k], 0],
            y=x[k, mask_x[k], 1],
            showlegend=k == 0,
            name=name,
            **kwargs,
        )
        for k in range(x.shape[0])
    ]


def configuration_paths() -> Iterable[os.PathLike]:
    """Return the paths of the learning and Waymo config files of the package."""
    working_dir = os.path.dirname(os.path.realpath(__file__))
    return [
        os.path.join(
            working_dir,
            "../../risk_biased/config",
            config_file,
        )
        for config_file in ("learning_config.py", "waymo_config.py")
    ]


def _as_tensor(value, dtype, device) -> torch.Tensor:
    """Convert one dataset field to a ``torch`` tensor of ``dtype`` on ``device``."""
    return torch.from_numpy(numpy.array(value).astype(dtype)).to(device)


def load_item(index: int, dataset: Dataset, device: str = "cpu") -> Tuple:
    """Load one scene of ``dataset`` as tensors ready for the predictor.

    Args:
        index: row of the dataset to load.
        dataset: dataset whose rows expose the keys read below.
        device: device the tensors are moved to.

    Returns:
        ``((x, mask_x, map_data, mask_map, offset, x_ego, y_ego), y, mask_y, mask_loss)``.
        Coordinates are float32 tensors and masks are boolean tensors.
    """
    # Read the row once: every ``dataset[index]`` access decodes the whole row again.
    item = dataset[index]
    floats = numpy.float32
    # ``numpy.bool8`` was deprecated in NumPy 1.24 and removed in NumPy 2.0;
    # ``numpy.bool_`` is the same dtype.
    bools = numpy.bool_

    x = _as_tensor(item["x"], floats, device)
    mask_x = _as_tensor(item["mask_x"], bools, device)
    y = _as_tensor(item["y"], floats, device)
    mask_y = _as_tensor(item["mask_y"], bools, device)
    mask_loss = _as_tensor(item["mask_loss"], bools, device)
    map_data = _as_tensor(item["map_data"], floats, device)
    mask_map = _as_tensor(item["mask_map"], bools, device)
    offset = _as_tensor(item["offset"], floats, device)
    x_ego = _as_tensor(item["x_ego"], floats, device)
    y_ego = _as_tensor(item["y_ego"], floats, device)
    return (x, mask_x, map_data, mask_map, offset, x_ego, y_ego), y, mask_y, mask_loss


def _empty_marker_trace() -> go.Scatter:
    """Return an empty, legend-less marker trace (animation placeholder / padding)."""
    return go.Scatter(x=[], y=[], mode="markers", showlegend=False)


def build_data(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
) -> Dict[str, list]:
    """Run the predictor on one scene and build the Plotly traces and frames.

    Args:
        predictor: trained trajectory predictor.
        dataset: dataset to read the scene from.
        index: row of the scene in ``dataset``.
        risk_level: risk level passed to ``predictor.predict_step`` (``None``
            disables the bias, see ``prediction_plot``).
        n_samples: number of forecast samples to draw; must be at least 1.

    Returns:
        A dict with ``"data"`` (static traces) and ``"frames"`` (animation frames),
        suitable for ``go.Figure(**result)``.

    Raises:
        ValueError: if ``n_samples`` is smaller than 1.
    """
    # Slider widgets may hand back floats; dataset indexing and ``range`` need ints.
    index = int(index)
    n_samples = int(n_samples)
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")

    batch, y, mask_y, mask_loss = load_item(index, dataset, predictor.device)
    predictions = predictor.predict_step(
        batch=batch,
        risk_level=risk_level,
        n_samples=n_samples,
    )
    offset = batch[4]
    y = predictor._unnormalize_trajectory(y, offset)
    x = predictor._unnormalize_trajectory(batch[0], offset)

    numpy_data = to_numpy(
        predictions=predictions,
        y=y,
        mask_y=mask_y,
        x=x,
        mask_x=batch[1],
        map_data=batch[2],
        mask_map=batch[3],
        mask_pred=mask_loss,
    )

    # Keep the first (presumably only) batch element of every array.
    x = numpy_data["x"][0]
    mask_x = numpy_data["mask_x"][0]
    y = numpy_data["y"][0]
    mask_y = numpy_data["mask_y"][0]
    pred = numpy_data["predictions"][0]
    mask_pred = numpy_data["mask_pred"][0]
    map_data = numpy_data["map_data"][0]
    mask_map = numpy_data["mask_map"][0]

    # Row 0 is drawn as the ego (blue) and row 1 as the predicted agent (green).
    data_x = get_scatter_data(
        x,
        mask_x,
        mode="lines",
        line=dict(width=2, color="black"),
        name="Past",
    )
    ego_present = get_scatter_data(
        x=x[0:1, -1:],
        mask_x=mask_x[0:1, -1:],
        mode="markers",
        marker=dict(color="blue", size=20, opacity=0.5),
        name="Ego",
    )
    agent_present = get_scatter_data(
        x=x[1:2, -1:],
        mask_x=mask_x[1:2, -1:],
        mode="markers",
        marker=dict(color="green", size=20, opacity=0.5),
        name="Agent",
    )
    data_y = get_scatter_data(
        y,
        mask_y,
        mode="lines",
        line=dict(width=2, color="green"),
        name="Ground truth",
    )
    data_map = get_scatter_data(
        map_data,
        mask_map,
        mode="lines",
        line=dict(width=15, color="gray"),
        opacity=0.3,
        name="Centerline",
    )

    data_pred = []
    forecasts_end = []
    for i in range(n_samples):
        data_pred += get_scatter_data(
            pred[:, i],
            mask_pred,
            mode="lines",
            line=dict(width=2, color="red"),
            name="Forecast",
        )
        forecasts_end += get_scatter_data(
            pred[:, i, -1:],
            mask_pred[:, -1:],
            mode="markers",
            marker=dict(color="red", size=10, opacity=0.5, symbol="x"),
            name="Forecast end",
        )

    static_data = data_map + data_x + data_y + data_pred + ego_present + agent_present + forecasts_end

    animation_opacity = 0.5
    # Frames without an explicit ``traces`` list overwrite fig.data[0], fig.data[1], ...
    # i.e. the map centerlines. Reserve dedicated placeholder traces after the
    # static ones and make every frame target exactly those.
    # Future frames hold: agent ground truth, other ground truths, one trace per
    # forecast sample, and the ego.
    n_animated = n_samples + 3
    first_animated = len(static_data)
    animated_traces = list(range(first_animated, first_animated + n_animated))
    placeholders = [_empty_marker_trace() for _ in range(n_animated)]

    frames_x = [
        go.Frame(
            data=[
                go.Scatter(
                    x=x[mask_x[:, k], k, 0],
                    y=x[mask_x[:, k], k, 1],
                    mode="markers",
                    opacity=animation_opacity,
                    marker=dict(color="black", size=15),
                    showlegend=False,
                ),
                go.Scatter(
                    x=x[0:1, k, 0],
                    y=x[0:1, k, 1],
                    mode="markers",
                    opacity=animation_opacity,
                    marker=dict(color="blue", size=15),
                    showlegend=False,
                ),
                # Clear the remaining animated traces (e.g. forecasts left over
                # from a previous play-through).
                *(_empty_marker_trace() for _ in range(n_animated - 2)),
            ],
            traces=animated_traces,
        )
        for k in range(x.shape[1])
    ]

    frames_y_pred = []
    for k in range(y.shape[1]):
        cur_gt_agent_data = go.Scatter(
            x=y[1:2][mask_y[1:2, k], k, 0],
            y=y[1:2][mask_y[1:2, k], k, 1],
            mode="markers",
            opacity=animation_opacity,
            marker=dict(color="green", size=15),
            showlegend=False,
        )
        cur_gt_future_data = go.Scatter(
            x=y[2:][mask_y[2:, k], k, 0],
            y=y[2:][mask_y[2:, k], k, 1],
            mode="markers",
            opacity=animation_opacity,
            marker=dict(color="black", size=15),
            showlegend=False,
        )
        cur_pred_data = [
            go.Scatter(
                x=pred[mask_pred[:, k], i, k, 0],
                y=pred[mask_pred[:, k], i, k, 1],
                mode="markers",
                opacity=animation_opacity,
                marker=dict(color="red", size=15),
                showlegend=False,
            )
            for i in range(n_samples)
        ]
        cur_ego_data = go.Scatter(
            x=y[0:1, k, 0],
            y=y[0:1, k, 1],
            mode="markers",
            opacity=animation_opacity,
            marker=dict(color="blue", size=15),
            showlegend=False,
        )
        cur_data = [cur_gt_agent_data, cur_gt_future_data, *cur_pred_data, cur_ego_data]
        frames_y_pred.append(go.Frame(data=cur_data, traces=animated_traces))

    return {"frames": frames_x + frames_y_pred, "data": static_data + placeholders}


def prediction_plot(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int = 1,
    use_biaser: bool = True,
) -> go.Figure:
    """Build the animated road-scene figure for one dataset scene.

    Args:
        predictor: trained trajectory predictor.
        dataset: dataset to read the scene from.
        index: row of the scene in ``dataset``.
        risk_level: risk level of the biased predictions; ignored (passed as
            ``None``) when ``use_biaser`` is False.
        n_samples: number of forecast samples.
        use_biaser: whether to make risk-biased predictions.

    Returns:
        A Plotly figure with Play / Pause animation buttons.
    """
    range_radius = 80
    if use_biaser:
        risk_level = float(risk_level)
    else:
        risk_level = None
    layout = go.Layout(
        xaxis=dict(
            range=[-0.5 * range_radius, 1.5 * range_radius],
            autorange=False,
            zeroline=False,
        ),
        yaxis=dict(
            range=[-range_radius, range_radius],
            autorange=False,
            zeroline=False,
        ),
        title_text="Road Scene",
        hovermode="closest",
        width=600,
        height=300,
        updatemenus=[
            dict(
                type="buttons",
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        args=[
                            None,
                            dict(
                                transition=dict(duration=100),
                                frame=dict(duration=100, redraw=False),
                                mode="immediate",
                                fromcurrent=True,
                            ),
                        ],
                    ),
                    dict(
                        label="Pause",
                        method="animate",
                        args=[
                            [None],
                            {
                                "frame": {"duration": 0, "redraw": False},
                                "mode": "immediate",
                                "transition": {"duration": 0},
                            },
                        ],
                    ),
                ],
            )
        ],
    )
    fig = go.Figure(
        **build_data(predictor, dataset, index, risk_level, n_samples),
        layout=layout,
    )
    fig.update_geos(projection_type="equirectangular", visible=True, resolution=110)
    return fig


def get_figure(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
) -> go.Figure:
    """Return the risk-biased prediction figure for one scene."""
    return prediction_plot(
        predictor, dataset, index, risk_level, n_samples, use_biaser=True
    )


def update_figure(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
    image=None,
) -> go.Figure:
    """Gradio callback: rebuild the figure from the current widget values.

    ``image`` receives the current plot value from Gradio and is unused.
    """
    return prediction_plot(
        predictor, dataset, index, risk_level, n_samples, use_biaser=True
    )


def load_predictor_from_hf(
    model_source: str = "TRI-ML/risk_biased_model",
    config_name: str = "learning_config.py",
    checkpoint_name: str = "last.ckpt",
    device: str = "cpu",
) -> LitTrajectoryPredictor:
    """Download the config and checkpoint from the Hugging Face Hub and build the predictor.

    The ``SECRET_AUTH_TOKEN`` environment variable, if set, is used as the Hub token.

    Returns:
        The predictor with loaded weights, in eval mode, on ``device``.
    """
    token = os.getenv("SECRET_AUTH_TOKEN")
    config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=token)
    checkpoint_file = hf_hub_download(model_source, filename=checkpoint_name, use_auth_token=token)
    # NOTE(review): torch.load unpickles the downloaded file; only use trusted sources.
    ckpt = torch.load(checkpoint_file, map_location="cpu")
    cfg = Config.fromfile(config_file)
    predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
    predictor = load_weights(predictor, ckpt)
    predictor.eval()
    predictor = predictor.to(device)
    return predictor


def load_dataset_from_hf(data_source: str = "jmercat/risk_biased_dataset") -> Dataset:
    """Load the test split of ``data_source`` from the Hugging Face Hub."""
    dataset = load_dataset(data_source, split="test")
    return dataset


def main(load_from=None, cfg_path=None):
    """Launch the Gradio demo.

    Args:
        load_from: optional path to a local checkpoint. When omitted, the model
            is downloaded from the Hugging Face Hub.
        cfg_path: config file matching ``load_from``; required when
            ``load_from`` is given.

    Raises:
        ValueError: if ``load_from`` is given without ``cfg_path``.
    """
    if load_from is not None and cfg_path is None:
        raise ValueError("cfg_path is required when load_from is given")

    # Define the device to use
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Getting dataset")
    dataset = load_dataset_from_hf()
    if load_from is not None:
        cfg = Config.fromfile(cfg_path)
        predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
        predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
        # Same post-processing as load_predictor_from_hf: inference mode on `device`.
        predictor.eval()
        predictor = predictor.to(device)
    else:
        print("Getting model.")
        predictor = load_predictor_from_hf(device=device)

    ui_update_fn = partial(update_figure, predictor, dataset)

    with gr.Blocks() as interface:
        gr.Markdown(
            "# Risk-Aware Prediction\n\n"
            "Make predictions for the green agent with a risk-seeking bias towards "
            "the ego vehicle in blue.\n\n"
            "The risk level is a value between 0 and 1, where 0 is not risk-seeking "
            "and 1 is the most risk-seeking.\n\n"
            "For more information, see the paper "
            "[RAP: Risk-Aware Prediction for Robust Planning](https://arxiv.org/abs/2210.01368) "
            "published at CoRL 2022.\n"
        )
        initial_index = 27
        initial_n_samples = 10
        image = gr.Plot(get_figure(predictor, dataset, initial_index, 0, initial_n_samples))
        interface.queue()
        index = gr.Slider(
            minimum=0,
            maximum=len(dataset) - 1,
            step=1,
            value=initial_index,
            label="Index",
        )
        risk_level = gr.Slider(minimum=0, maximum=1, step=0.01, label="Risk")
        n_samples = gr.Slider(
            minimum=1, maximum=20, step=1, value=initial_n_samples, label="Num Samples"
        )
        # A button's displayed text is its `value`; `label` left it showing "Run".
        button = gr.Button(value="Re-sample")

        # The plot is only refreshed on button click. Refreshing on every slider
        # `.change` ran on the first change and ignored changes made during the
        # computation, leaving the plot out of sync with the sliders.
        button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
    interface.launch(debug=False)


if __name__ == "__main__":
    fire.Fire(main)