Spaces:
Running
Running
separate load data and load model, set debug False
Browse files
scripts/scripts_utils/plotly_interface.py
CHANGED
@@ -244,7 +244,7 @@ def prediction_plot(
|
|
244 |
n_samples: int = 1,
|
245 |
use_biaser: bool = True,
|
246 |
) -> go.Figure:
|
247 |
-
range_radius =
|
248 |
if use_biaser:
|
249 |
risk_level = float(risk_level)
|
250 |
else:
|
@@ -262,8 +262,8 @@ def prediction_plot(
|
|
262 |
),
|
263 |
title_text="Road Scene",
|
264 |
hovermode="closest",
|
265 |
-
width=
|
266 |
-
height=
|
267 |
updatemenus=[
|
268 |
dict(
|
269 |
type="buttons",
|
@@ -332,8 +332,7 @@ def update_figure(
|
|
332 |
|
333 |
return fig
|
334 |
|
335 |
-
def
|
336 |
-
dataset = load_dataset(data_source, split="test")
|
337 |
config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN'))
|
338 |
ckpt = torch.load(hf_hub_download(model_source, filename=checkpoint_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN')), map_location="cpu")
|
339 |
cfg = Config.fromfile(config_file)
|
@@ -342,23 +341,31 @@ def load_from_huggingface(model_source: str = "TRI-ML/risk_biased_model", data_
|
|
342 |
predictor.eval()
|
343 |
predictor = predictor.to(device)
|
344 |
|
345 |
-
return predictor
|
|
|
|
|
|
|
|
|
346 |
|
347 |
def main(load_from=None, cfg_path=None):
|
348 |
# Define the device to use
|
349 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
|
351 |
# Do the same thing as above but using the gradio blocks API
|
352 |
with gr.Blocks() as interface:
|
353 |
-
|
354 |
-
predictor, dataset = load_from_huggingface(device=device)
|
355 |
-
|
356 |
-
if load_from is not None:
|
357 |
-
cfg = Config.fromfile(cfg_path)
|
358 |
-
predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
|
359 |
-
predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
|
360 |
|
361 |
-
ui_update_fn = partial(update_figure, predictor, dataset)
|
362 |
gr.Markdown(
|
363 |
"""
|
364 |
# Risk-Aware Prediction
|
@@ -391,7 +398,7 @@ def main(load_from=None, cfg_path=None):
|
|
391 |
# n_samples.change(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
|
392 |
button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
|
393 |
|
394 |
-
interface.launch(debug=
|
395 |
|
396 |
|
397 |
if __name__ == "__main__":
|
|
|
244 |
n_samples: int = 1,
|
245 |
use_biaser: bool = True,
|
246 |
) -> go.Figure:
|
247 |
+
range_radius = 50
|
248 |
if use_biaser:
|
249 |
risk_level = float(risk_level)
|
250 |
else:
|
|
|
262 |
),
|
263 |
title_text="Road Scene",
|
264 |
hovermode="closest",
|
265 |
+
width=600,
|
266 |
+
height=300,
|
267 |
updatemenus=[
|
268 |
dict(
|
269 |
type="buttons",
|
|
|
332 |
|
333 |
return fig
|
334 |
|
335 |
+
def load_predictor_from_hf(model_source: str = "TRI-ML/risk_biased_model", config_name: str="learning_config.py", checkpoint_name: str = "last.ckpt", device: str = "cpu") -> Tuple[LitTrajectoryPredictor, Dataset]:
|
|
|
336 |
config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN'))
|
337 |
ckpt = torch.load(hf_hub_download(model_source, filename=checkpoint_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN')), map_location="cpu")
|
338 |
cfg = Config.fromfile(config_file)
|
|
|
341 |
predictor.eval()
|
342 |
predictor = predictor.to(device)
|
343 |
|
344 |
+
return predictor
|
345 |
+
|
346 |
+
def load_dataset_from_hf(data_source: str = "jmercat/risk_biased_dataset") -> Dataset:
|
347 |
+
dataset = load_dataset(data_source, split="test")
|
348 |
+
return dataset
|
349 |
|
350 |
def main(load_from=None, cfg_path=None):
|
351 |
# Define the device to use
|
352 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
353 |
+
print("Getting dataset")
|
354 |
+
dataset = load_dataset_from_hf()
|
355 |
+
|
356 |
+
if load_from is not None:
|
357 |
+
cfg = Config.fromfile(cfg_path)
|
358 |
+
predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
|
359 |
+
predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
|
360 |
+
else:
|
361 |
+
print("Getting model.")
|
362 |
+
predictor = load_predictor_from_hf(device=device)
|
363 |
+
|
364 |
+
ui_update_fn = partial(update_figure, predictor, dataset)
|
365 |
|
366 |
# Do the same thing as above but using the gradio blocks API
|
367 |
with gr.Blocks() as interface:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
|
|
|
369 |
gr.Markdown(
|
370 |
"""
|
371 |
# Risk-Aware Prediction
|
|
|
398 |
# n_samples.change(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
|
399 |
button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
|
400 |
|
401 |
+
interface.launch(debug=False)
|
402 |
|
403 |
|
404 |
if __name__ == "__main__":
|