Spaces:
Running
Running
from datasets import load_dataset, Dataset | |
import fire | |
from functools import partial, update_wrapper | |
import numpy | |
import os | |
from typing import Dict, Iterable, Tuple | |
import sys | |
import time | |
import torch | |
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
from mmcv import Config | |
import plotly.graph_objects as go | |
from torch.utils.data.dataloader import DataLoader | |
from risk_biased.utils.load_model import get_predictor | |
from risk_biased.utils.torch_utils import load_weights | |
from risk_biased.utils.waymo_dataloader import WaymoDataloaders | |
from risk_biased.predictors.biased_predictor import ( | |
LitTrajectoryPredictor, | |
) | |
def to_numpy(**kwargs):
    """Convert each keyword tensor to a NumPy array.

    Every value is detached from the autograd graph and copied to the CPU
    before conversion. The result is a new dict with the same keys as
    ``kwargs``.
    """
    return {name: tensor.detach().cpu().numpy() for name, tensor in kwargs.items()}
def get_scatter_data(x, mask_x, name, **kwargs):
    """Build one plotly scatter trace per row of ``x``.

    Row ``k`` contributes the points selected by ``mask_x[k]``. Column 0 of
    the last axis is plotted horizontally and column 1 vertically. All traces
    share ``name``, and only the first one is shown in the legend. Remaining
    keyword arguments are forwarded unchanged to ``go.Scatter``.
    """
    traces = []
    for row in range(x.shape[0]):
        points = x[row]
        keep = mask_x[row]
        traces.append(
            go.Scatter(
                x=points[keep, 0],
                y=points[keep, 1],
                showlegend=(row == 0),
                name=name,
                **kwargs,
            )
        )
    return traces
def configuration_paths() -> Iterable[os.PathLike]:
    """Return the paths of the learning and Waymo configuration files.

    The paths are built from the real (symlink-resolved) directory of this
    module, pointing into ``../../risk_biased/config``.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    config_dir = os.path.join(module_dir, "../../risk_biased/config")
    file_names = ("learning_config.py", "waymo_config.py")
    return [os.path.join(config_dir, file_name) for file_name in file_names]
def _field_to_tensor(row, key, dtype, device):
    """Convert ``row[key]`` to a torch tensor with NumPy ``dtype`` on ``device``."""
    return torch.from_numpy(numpy.array(row[key]).astype(dtype)).to(device)


def load_item(index: int, dataset: "Dataset", device: str = "cpu") -> Tuple:
    """Load one dataset example as torch tensors ready for the predictor.

    Args:
        index: position of the example in ``dataset``.
        dataset: indexable collection whose items are mappings with the keys
            ``x``, ``mask_x``, ``y``, ``mask_y``, ``mask_loss``, ``map_data``,
            ``mask_map``, ``offset``, ``x_ego`` and ``y_ego``.
        device: target passed to ``Tensor.to`` (a string or ``torch.device``).

    Returns:
        ``(batch, y, mask_y, mask_loss)``, where ``batch`` is
        ``(x, mask_x, map_data, mask_map, offset, x_ego, y_ego)``. Mask fields
        are ``torch.bool`` and all other fields are ``torch.float32``.
    """
    # Fetch the row a single time: each ``dataset[index]`` on a Hugging Face
    # Dataset formats a complete example, so repeating it per field multiplied
    # that cost by the number of fields.
    row = dataset[index]

    def as_float(key):
        return _field_to_tensor(row, key, numpy.float32, device)

    def as_bool(key):
        return _field_to_tensor(row, key, numpy.bool_, device)

    x = as_float("x")
    mask_x = as_bool("mask_x")
    y = as_float("y")
    mask_y = as_bool("mask_y")
    mask_loss = as_bool("mask_loss")
    map_data = as_float("map_data")
    mask_map = as_bool("mask_map")
    offset = as_float("offset")
    x_ego = as_float("x_ego")
    y_ego = as_float("y_ego")
    return (x, mask_x, map_data, mask_map, offset, x_ego, y_ego), y, mask_y, mask_loss
def build_data(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
) -> Dict[str, list]:
    """Run the predictor on one dataset item and build plotly traces and frames.

    Args:
        predictor: model providing ``predict_step`` (forecasts) and
            ``_unnormalize_trajectory`` (applied to the past and future
            trajectories with the item's offset).
        dataset: dataset read through ``load_item``.
        index: position of the item in ``dataset``.
        risk_level: forwarded unchanged to ``predictor.predict_step``.
        n_samples: number of forecast samples to draw; must be >= 1.

    Returns:
        A dict with ``"data"`` (list of static ``go.Scatter`` traces) and
        ``"frames"`` (list of ``go.Frame``: one per past time step, followed by
        one per future time step), meant to be unpacked into ``go.Figure``.

    Raises:
        AssertionError: if ``n_samples < 1``.
    """
    assert n_samples >= 1
    batch, y, mask_y, mask_loss = load_item(index, dataset, predictor.device)
    predictions = predictor.predict_step(
        batch=batch,
        risk_level=risk_level,
        n_samples=n_samples,
    )
    # batch layout from load_item: (x, mask_x, map_data, mask_map, offset, x_ego, y_ego)
    offset = batch[4]
    y = predictor._unnormalize_trajectory(y, offset)
    x = predictor._unnormalize_trajectory(batch[0], offset)
    # NOTE(review): predictions and map_data are plotted without this
    # unnormalization. Presumably predict_step already returns unnormalized
    # trajectories and the map is stored in the same frame; TODO confirm.
    numpy_data = to_numpy(
        predictions=predictions,
        y=y,
        mask_y=mask_y,
        x=x,
        mask_x=batch[1],
        map_data=batch[2],
        mask_map=batch[3],
        mask_pred=mask_loss,
    )
    # Keep only the first element along axis 0 of every array (presumably the
    # single scene of a batch of one; TODO confirm).
    x = numpy_data["x"][0]
    mask_x = numpy_data["mask_x"][0]
    y = numpy_data["y"][0]
    mask_y = numpy_data["mask_y"][0]
    pred = numpy_data["predictions"][0]
    mask_pred = numpy_data["mask_pred"][0]
    map_data = numpy_data["map_data"][0]
    mask_map = numpy_data["mask_map"][0]
    marker_size = 12
    # Past trajectories of every row (agent).
    data_x = get_scatter_data(
        x,
        mask_x,
        mode="lines",
        line=dict(width=2, color="black"),
        name="Past",
    )
    # Last past point of row 0, drawn as the ego vehicle.
    ego_present = get_scatter_data(
        x=x[0:1, -1:],
        mask_x=mask_x[0:1, -1:],
        mode="markers",
        marker=dict(color="blue", size=marker_size, opacity=0.5),
        name="Ego",
    )
    # Last past point of row 1, drawn as the predicted agent.
    agent_present = get_scatter_data(
        x=x[1:2, -1:],
        mask_x=mask_x[1:2, -1:],
        mode="markers",
        marker=dict(color="green", size=marker_size, opacity=0.5),
        name="Agent",
    )
    # Ground-truth future trajectories of every row.
    data_y = get_scatter_data(
        y,
        mask_y,
        mode="lines",
        line=dict(width=2, color="green"),
        name="Ground truth",
    )
    # Map polylines drawn as wide translucent lines.
    data_map = get_scatter_data(
        map_data,
        mask_map,
        mode="lines",
        line=dict(width=15, color="gray"),
        opacity=0.3,
        name="Centerline",
    )
    data_pred = []
    forecasts_end = []
    for i in range(n_samples):
        # Forecast sample i for every agent. pred appears to be indexed as
        # pred[agent, sample, time, coord] (inferred from the slicing; TODO confirm).
        cur_data_pred = get_scatter_data(
            pred[:, i],
            mask_pred,
            mode="lines",
            line=dict(width=2, color="red"),
            name="Forecast",
        )
        data_pred += cur_data_pred
        # Final forecast point of sample i.
        forecast_end = get_scatter_data(
            pred[:, i, -1:],
            mask_pred[:, -1:],
            mode="markers",
            marker=dict(color="red", size=marker_size/2, opacity=0.5, symbol="x"),
            name="Forecast end",
        )
        forecasts_end += forecast_end
    # Trace order matters: plotly animation frames address traces by position.
    static_data = data_map + data_x + data_y + data_pred + ego_present + agent_present + forecasts_end
    animation_opacity = 0.5
    # One frame per past time step k: all visible agents, plus the ego (row 0) at step k.
    # NOTE(review): a go.Frame without ``traces=`` updates figure traces
    # 0..len(frame.data)-1, i.e. the first entries of static_data (map
    # centerlines). Playing the animation therefore presumably redraws those
    # map traces as markers. TODO confirm whether this is intended.
    frames_x = [
        go.Frame(
            data=[
                go.Scatter(
                    x=x[mask_x[:, k], k, 0],
                    y=x[mask_x[:, k], k, 1],
                    mode="markers",
                    opacity=animation_opacity,
                    marker=dict(color="black", size=marker_size),
                    showlegend=False,
                ),
                go.Scatter(
                    x=x[0:1, k, 0],
                    y=x[0:1, k, 1],
                    mode="markers",
                    opacity=animation_opacity,
                    marker=dict(color="blue", size=marker_size),
                    showlegend=False,
                ),
            ]
        )
        for k in range(x.shape[1])
    ]
    # One frame per future time step k.
    frames_y_pred = []
    for k in range(y.shape[1]):
        # Ground truth of the agent (row 1) at future step k.
        cur_gt_agent_data = go.Scatter(
            x=y[1:2][mask_y[1:2, k], k, 0],
            y=y[1:2][mask_y[1:2, k], k, 1],
            mode="markers",
            opacity=animation_opacity,
            marker=dict(color="green", size=marker_size),
        )
        # Ground truth of the remaining agents (rows 2 and above) at step k.
        cur_gt_future_data = go.Scatter(
            x=y[2:][mask_y[2:, k], k, 0],
            y=y[2:][mask_y[2:, k], k, 1],
            mode="markers",
            opacity=animation_opacity,
            marker=dict(color="black", size=marker_size),
        )
        # Position of every forecast sample at step k.
        cur_pred_data = []
        for i in range(n_samples):
            cur_pred_data.append(
                go.Scatter(
                    x=pred[mask_pred[:, k], i, k, 0],
                    y=pred[mask_pred[:, k], i, k, 1],
                    mode="markers",
                    opacity=animation_opacity,
                    marker=dict(color="red", size=marker_size),
                    showlegend=False,
                )
            )
        # Ego (row 0) ground truth at step k.
        cur_ego_data = go.Scatter(
            x=y[0:1, k, 0],
            y=y[0:1, k, 1],
            mode="markers",
            opacity=animation_opacity,
            marker=dict(color="blue", size=marker_size),
        )
        cur_data = [cur_gt_agent_data, cur_gt_future_data, *cur_pred_data, cur_ego_data]
        frame = go.Frame(data=cur_data)
        frames_y_pred.append(frame)
    # Past frames play first, then future frames.
    return {"frames": frames_x + frames_y_pred, "data": static_data}
def prediction_plot(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int = 1,
    use_biaser: bool = True,
) -> go.Figure:
    """Build the animated road-scene figure for one dataset item.

    Args:
        predictor: trajectory predictor used by ``build_data``.
        dataset: dataset the item is read from.
        index: position of the item in ``dataset``.
        risk_level: risk level forwarded to the predictor (cast to ``float``).
        n_samples: number of forecast samples to draw.
        use_biaser: when False, ``None`` is passed as the risk level instead.

    Returns:
        A plotly figure with fixed axis ranges and Play/Pause animation buttons.
    """
    range_radius = 80
    risk_level = float(risk_level) if use_biaser else None

    play_button = dict(
        label="Play",
        method="animate",
        args=[
            None,
            dict(
                frame=dict(duration=100, redraw=False),
                mode="immediate",
                fromcurrent=True,
            ),
        ],
    )
    pause_button = dict(
        label="Pause",
        method="animate",
        args=[
            [None],
            {
                "frame": {"duration": 0, "redraw": False},
                "mode": "immediate",
                "transition": {"duration": 0},
            },
        ],
    )
    # The view window is offset forward along x (from -0.5 r to 1.5 r) and symmetric in y.
    x_axis = dict(
        range=[-0.5 * range_radius, 1.5 * range_radius],
        autorange=False,
        zeroline=False,
    )
    y_axis = dict(
        range=[-range_radius, range_radius],
        autorange=False,
        zeroline=False,
    )
    layout = go.Layout(
        xaxis=x_axis,
        yaxis=y_axis,
        title_text="Road Scene",
        hovermode="closest",
        width=800,
        height=600,
        updatemenus=[dict(type="buttons", buttons=[play_button, pause_button])],
    )

    figure_content = build_data(predictor, dataset, index, risk_level, n_samples)
    fig = go.Figure(**figure_content, layout=layout)
    fig.update_geos(projection_type="equirectangular", visible=True, resolution=110)
    return fig
def get_figure(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
) -> go.Figure:
    """Return the prediction figure for ``dataset[index]`` with the biaser enabled."""
    return prediction_plot(
        predictor,
        dataset,
        index,
        risk_level,
        n_samples,
        use_biaser=True,
    )
def update_figure(
    predictor: LitTrajectoryPredictor,
    dataset: Dataset,
    index: int,
    risk_level: float,
    n_samples: int,
    image=None,
) -> go.Figure:
    """Gradio callback that rebuilds the prediction figure from the slider values.

    Args:
        predictor: trajectory predictor (bound with ``functools.partial``).
        dataset: dataset the item is read from (bound with ``functools.partial``).
        index: value of the index slider.
        risk_level: value of the risk slider.
        n_samples: value of the sample-count slider.
        image: current value of the plot component. It is received because the
            plot is listed among the click inputs, and it is ignored.

    Returns:
        A freshly built figure, identical to what ``get_figure`` produces.
    """
    # Gradio documents Slider values as floats, while dataset indexing and
    # ``range`` need ints, so normalize the integer-step sliders.
    return get_figure(
        predictor,
        dataset,
        int(round(index)),
        risk_level,
        int(round(n_samples)),
    )
def load_predictor_from_hf(
    model_source: str = "TRI-ML/risk_biased_model",
    config_name: str = "learning_config.py",
    checkpoint_name: str = "last.ckpt",
    device: str = "cpu",
) -> LitTrajectoryPredictor:
    """Download a config and checkpoint from the Hugging Face Hub and build the predictor.

    Args:
        model_source: Hub repository id holding both files.
        config_name: config file in the repository, read with ``mmcv.Config``.
        checkpoint_name: checkpoint file in the repository.
        device: device the predictor is moved to.

    Returns:
        The predictor with the checkpoint weights loaded, in eval mode, on ``device``.
    """
    # Token forwarded to both downloads; None when the variable is unset.
    auth_token = os.getenv('SECRET_AUTH_TOKEN')
    config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=auth_token)
    checkpoint_file = hf_hub_download(model_source, filename=checkpoint_name, use_auth_token=auth_token)
    # SECURITY NOTE(review): torch.load unpickles the checkpoint, so only load
    # checkpoints from repositories you trust.
    ckpt = torch.load(checkpoint_file, map_location="cpu")
    cfg = Config.fromfile(config_file)
    predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
    predictor = load_weights(predictor, ckpt)
    predictor.eval()
    return predictor.to(device)
def load_dataset_from_hf(data_source: str = "jmercat/risk_biased_dataset") -> Dataset:
    """Load the ``test`` split of the Hugging Face dataset ``data_source``."""
    return load_dataset(data_source, split="test")
def main(load_from=None, cfg_path=None):
    """Launch the Gradio risk-aware prediction demo.

    Args:
        load_from: optional path to a local checkpoint. When omitted, the model
            is downloaded from the Hugging Face Hub.
        cfg_path: config file matching ``load_from``; required when
            ``load_from`` is given.

    Raises:
        ValueError: if ``load_from`` is given without ``cfg_path``.
    """
    # Define the device to use
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fail fast, before the slow dataset download.
    if load_from is not None and cfg_path is None:
        raise ValueError("cfg_path is required when load_from is given")
    print("Getting dataset")
    dataset = load_dataset_from_hf()
    if load_from is not None:
        cfg = Config.fromfile(cfg_path)
        predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
        predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
        # Match load_predictor_from_hf: put the model in inference mode and move
        # it to the selected device. Without this, a local model stayed in
        # training mode on the CPU.
        predictor.eval()
        predictor = predictor.to(device)
    else:
        print("Getting model.")
        predictor = load_predictor_from_hf(device=device)
    ui_update_fn = partial(update_figure, predictor, dataset)
    with gr.Blocks() as interface:
        gr.Markdown(
            """
# Risk-Aware Prediction
Make predictions for the green agent with a risk-seeking bias towards the ego vehicle in blue.
The risk level is a value between 0 and 1, where 0 is not risk-seeking and 1 is the most risk-seeking.
Once the sliders are set, click the "Run" button to see the predictions.
The play button will animate the prediction over time (it is slow especially with many samples).
For more information, see the paper [RAP: Risk-Aware Prediction for Robust Planning](https://arxiv.org/abs/2210.01368) published at [CoRL 2022](https://corl2022.org/).
""")
        initial_index = 27
        initial_n_samples = 10
        image = gr.Plot(get_figure(predictor, dataset, initial_index, 0, initial_n_samples))
        interface.queue()
        index = gr.Slider(
            minimum=0,
            maximum=len(dataset)-1,
            step=1,
            value=initial_index,
            label="Index",
        )
        risk_level = gr.Slider(minimum=0, maximum=1, step=0.01, label="Risk")
        n_samples = gr.Slider(minimum=1, maximum=20, step=1, value=initial_n_samples, label="Number of prediction samples")
        # A Button's caption is its ``value``. ``label`` is not a Button
        # argument, and newer Gradio releases reject it.
        button = gr.Button(value="Run")
        # Recompute only on an explicit click. Live ``.change`` handlers on the
        # sliders ran on the first change and ignored changes made during
        # computation, which left the plot out of sync with the sliders.
        button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
    interface.launch(debug=False)
if __name__ == "__main__":
    # Expose main's keyword arguments (load_from, cfg_path) as command-line flags.
    fire.Fire(component=main)