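"""Gradio demo for RAP (Risk-Aware Prediction).

Loads a pretrained risk-biased trajectory predictor and a Waymo-based test set
from the Hugging Face Hub and plots animated forecasts for a selected scene at
a chosen risk level.
"""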
from datasets import load_dataset, Dataset
import fire
from functools import partial
import numpy
import os
from typing import Dict, Iterable, Tuple
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from mmcv import Config
import plotly.graph_objects as go
from risk_biased.utils.load_model import get_predictor
from risk_biased.utils.torch_utils import load_weights
from risk_biased.utils.waymo_dataloader import WaymoDataloaders
from risk_biased.predictors.biased_predictor import (
LitTrajectoryPredictor,
)
def to_numpy(**kwargs):
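    """Detach the given tensors, move them to CPU, and return them as numpy arrays keyed by argument name."""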
dic_outputs = {}
for k, v in kwargs.items():
dic_outputs[k] = v.detach().cpu().numpy()
return dic_outputs
def get_scatter_data(x, mask_x, name, **kwargs):
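    """Build one masked go.Scatter trace per element along the first axis of x; only the first trace appears in the legend."""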
return [
go.Scatter(
x=x[k, mask_x[k], 0],
y=x[k, mask_x[k], 1],
showlegend=k == 0,
name=name,
**kwargs,
)
for k in range(x.shape[0])
]
def configuration_paths() -> Iterable[os.PathLike]:
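    """Return the paths of the learning and Waymo config files located relative to this script."""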
working_dir = os.path.dirname(os.path.realpath(__file__))
return [
os.path.join(
working_dir,
"../../risk_biased/config",
config_file,
)
for config_file in ("learning_config.py", "waymo_config.py")
]
def load_item(index: int, dataset: Dataset, device: str = "cpu") -> Tuple:
    """Load one dataset sample and convert its fields to tensors on the given device."""
    item = dataset[index]

    def to_tensor(key: str, dtype) -> torch.Tensor:
        return torch.from_numpy(numpy.array(item[key]).astype(dtype)).to(device)

    x = to_tensor("x", numpy.float32)
    mask_x = to_tensor("mask_x", numpy.bool_)
    y = to_tensor("y", numpy.float32)
    mask_y = to_tensor("mask_y", numpy.bool_)
    mask_loss = to_tensor("mask_loss", numpy.bool_)
    map_data = to_tensor("map_data", numpy.float32)
    mask_map = to_tensor("mask_map", numpy.bool_)
    offset = to_tensor("offset", numpy.float32)
    x_ego = to_tensor("x_ego", numpy.float32)
    y_ego = to_tensor("y_ego", numpy.float32)
    return (x, mask_x, map_data, mask_map, offset, x_ego, y_ego), y, mask_y, mask_loss
def build_data(
predictor: LitTrajectoryPredictor,
dataset: Dataset,
index: int,
risk_level: float,
n_samples: int,
) -> Dict[str, list]:
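    """Run the predictor on one scene and build the static Plotly traces ("data") and the animation frames ("frames")."""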
assert n_samples >= 1
batch, y, mask_y, mask_loss = load_item(index, dataset, predictor.device)
predictions = predictor.predict_step(
batch=batch,
risk_level=risk_level,
n_samples=n_samples,
)
offset = batch[4]
y = predictor._unnormalize_trajectory(y, offset)
x = predictor._unnormalize_trajectory(batch[0], offset)
numpy_data = to_numpy(
predictions=predictions,
y=y,
mask_y=mask_y,
x=x,
mask_x=batch[1],
map_data=batch[2],
mask_map=batch[3],
mask_pred=mask_loss,
)
x = numpy_data["x"][0]
mask_x = numpy_data["mask_x"][0]
y = numpy_data["y"][0]
mask_y = numpy_data["mask_y"][0]
pred = numpy_data["predictions"][0]
mask_pred = numpy_data["mask_pred"][0]
map_data = numpy_data["map_data"][0]
mask_map = numpy_data["mask_map"][0]
data_x = get_scatter_data(
x,
mask_x,
mode="lines",
line=dict(width=2, color="black"),
name="Past",
)
ego_present = get_scatter_data(
x=x[0:1, -1:],
mask_x=mask_x[0:1, -1:],
mode="markers",
marker=dict(color="blue", size=20, opacity=0.5),
name="Ego",
)
agent_present = get_scatter_data(
x=x[1:2, -1:],
mask_x=mask_x[1:2, -1:],
mode="markers",
marker=dict(color="green", size=20, opacity=0.5),
name="Agent",
)
data_y = get_scatter_data(
y,
mask_y,
mode="lines",
line=dict(width=2, color="green"),
name="Ground truth",
)
data_map = get_scatter_data(
map_data,
mask_map,
mode="lines",
line=dict(width=15, color="gray"),
opacity=0.3,
name="Centerline",
)
data_pred = []
forecasts_end = []
for i in range(n_samples):
cur_data_pred = get_scatter_data(
pred[:, i],
mask_pred,
mode="lines",
line=dict(width=2, color="red"),
name="Forecast",
)
data_pred += cur_data_pred
forecast_end = get_scatter_data(
pred[:, i, -1:],
mask_pred[:, -1:],
mode="markers",
marker=dict(color="red", size=10, opacity=0.5, symbol="x"),
name="Forecast end",
)
forecasts_end += forecast_end
static_data = data_map + data_x + data_y + data_pred + ego_present + agent_present + forecasts_end
animation_opacity = 0.5
frames_x = [
go.Frame(
data=[
go.Scatter(
x=x[mask_x[:, k], k, 0],
y=x[mask_x[:, k], k, 1],
mode="markers",
opacity=animation_opacity,
marker=dict(color="black", size=15),
showlegend=False,
),
go.Scatter(
x=x[0:1, k, 0],
y=x[0:1, k, 1],
mode="markers",
opacity=animation_opacity,
marker=dict(color="blue", size=15),
showlegend=False,
),
]
)
for k in range(x.shape[1])
]
frames_y_pred = []
for k in range(y.shape[1]):
cur_gt_agent_data = go.Scatter(
x=y[1:2][mask_y[1:2, k], k, 0],
y=y[1:2][mask_y[1:2, k], k, 1],
mode="markers",
opacity=animation_opacity,
marker=dict(color="green", size=15),
)
cur_gt_future_data = go.Scatter(
x=y[2:][mask_y[2:, k], k, 0],
y=y[2:][mask_y[2:, k], k, 1],
mode="markers",
opacity=animation_opacity,
marker=dict(color="black", size=15),
)
cur_pred_data = []
for i in range(n_samples):
cur_pred_data.append(
go.Scatter(
x=pred[mask_pred[:, k], i, k, 0],
y=pred[mask_pred[:, k], i, k, 1],
mode="markers",
opacity=animation_opacity,
marker=dict(color="red", size=15),
showlegend=False,
)
)
cur_ego_data = go.Scatter(
x=y[0:1, k, 0],
y=y[0:1, k, 1],
mode="markers",
opacity=animation_opacity,
marker=dict(color="blue", size=15),
)
cur_data = [cur_gt_agent_data, cur_gt_future_data, *cur_pred_data, cur_ego_data]
frame = go.Frame(data=cur_data)
frames_y_pred.append(frame)
return {"frames": frames_x + frames_y_pred, "data": static_data}
def prediction_plot(
predictor: LitTrajectoryPredictor,
dataset: Dataset,
index: int,
risk_level: float,
n_samples: int = 1,
use_biaser: bool = True,
) -> go.Figure:
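    """Build an animated Plotly figure of the road scene: past tracks, ground truth, and risk-biased forecasts with Play/Pause controls."""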
range_radius = 80
if use_biaser:
risk_level = float(risk_level)
else:
risk_level = None
layout = go.Layout(
xaxis=dict(
range=[-0.5*range_radius, 1.5*range_radius],
autorange=False,
zeroline=False,
),
yaxis=dict(
range=[-range_radius, range_radius],
autorange=False,
zeroline=False,
),
title_text="Road Scene",
hovermode="closest",
width=600,
height=300,
updatemenus=[
dict(
type="buttons",
buttons=[
dict(
label="Play",
method="animate",
args=[
None,
dict(
transition=dict(duration=100),
frame=dict(duration=100, redraw=False),
mode="immediate",
fromcurrent=True,
),
],
),
                    dict(
                        label="Pause",
                        method="animate",
                        args=[
                            [None],
                            dict(
                                frame=dict(duration=0, redraw=False),
                                mode="immediate",
                                transition=dict(duration=0),
                            ),
                        ],
                    ),
],
)
],
)
fig = go.Figure(
**build_data(predictor, dataset, index, risk_level, n_samples),
layout=layout,
)
fig.update_geos(projection_type="equirectangular", visible=True, resolution=110)
return fig
def get_figure(
predictor: LitTrajectoryPredictor,
dataset: Dataset,
index: int,
risk_level: float,
n_samples: int,
) -> go.Figure:
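    """Build the figure for the given scene index, risk level, and number of samples (used for the initial plot)."""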
fig = prediction_plot(
predictor, dataset, index, risk_level, n_samples, use_biaser=True
)
fig.update_layout()
return fig
def update_figure(
predictor: LitTrajectoryPredictor,
dataset: Dataset,
index: int,
risk_level: float,
n_samples: int,
    image=None,
) -> go.Figure:
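    """Gradio callback: rebuild the figure from the current slider values; the current plot is passed in as image and ignored."""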
fig = prediction_plot(
predictor, dataset, index, risk_level, n_samples, use_biaser=True
)
fig.update_layout()
return fig
def load_predictor_from_hf(
    model_source: str = "TRI-ML/risk_biased_model",
    config_name: str = "learning_config.py",
    checkpoint_name: str = "last.ckpt",
    device: str = "cpu",
) -> LitTrajectoryPredictor:
    """Download the config and checkpoint from the Hugging Face Hub and return the predictor in eval mode on the given device."""
    auth_token = os.getenv("SECRET_AUTH_TOKEN")
    config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=auth_token)
    checkpoint_file = hf_hub_download(model_source, filename=checkpoint_name, use_auth_token=auth_token)
    ckpt = torch.load(checkpoint_file, map_location="cpu")
    cfg = Config.fromfile(config_file)
    predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
    predictor = load_weights(predictor, ckpt)
    predictor.eval()
    predictor = predictor.to(device)
    return predictor
def load_dataset_from_hf(data_source: str = "jmercat/risk_biased_dataset") -> Dataset:
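    """Load the test split of the demo dataset from the Hugging Face Hub."""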
dataset = load_dataset(data_source, split="test")
return dataset
def main(load_from=None, cfg_path=None):
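    """Launch the Gradio demo, loading either a local checkpoint (load_from, cfg_path) or the pretrained model from the Hugging Face Hub."""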
# Define the device to use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Getting dataset")
dataset = load_dataset_from_hf()
    if load_from is not None:
        cfg = Config.fromfile(cfg_path)
        predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
        predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
        predictor.eval()
        predictor = predictor.to(device)
else:
print("Getting model.")
predictor = load_predictor_from_hf(device=device)
ui_update_fn = partial(update_figure, predictor, dataset)
    # Build the UI with the gradio Blocks API
with gr.Blocks() as interface:
gr.Markdown(
"""
# Risk-Aware Prediction
Make predictions for the green agent with a risk-seeking bias towards the ego vehicle in blue.
The risk level is a value between 0 and 1, where 0 is not risk-seeking and 1 is the most risk-seeking.
If "Use Biased Encoder" is unchecked, the risk level is ignored and the model will make predictions without a risk-seeking bias.
For more information, see the paper [RAP: Risk-Aware Prediction for Robust Planning](https://arxiv.org/abs/2210.01368) published at CoRL 2022.
""")
initial_index = 27
initial_n_samples = 10
image = gr.Plot(get_figure(predictor, dataset, initial_index, 0, initial_n_samples))
interface.queue()
index = gr.Slider(
minimum=0,
maximum=len(dataset)-1,
step=1,
value=initial_index,
label="Index",
)
risk_level = gr.Slider(minimum=0, maximum=1, step=0.01, label="Risk")
n_samples = gr.Slider(minimum=1, maximum=20, step=1, value=initial_n_samples, label="Num Samples")
        button = gr.Button("Re-sample")
        # Removed the interactive (on-change) plot updates: the callback started running on the
        # first change, and any changes made while it was still computing were ignored,
        # which left the plot out of sync with the sliders.
# index.change(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
# risk_level.change(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
# n_samples.change(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
interface.launch(debug=False)
if __name__ == "__main__":
fire.Fire(main)