jmercat commited on
Commit
e683d10
·
1 Parent(s): 349aedb

separate load data and load model, set debug False

Browse files
scripts/scripts_utils/plotly_interface.py CHANGED
@@ -244,7 +244,7 @@ def prediction_plot(
244
  n_samples: int = 1,
245
  use_biaser: bool = True,
246
  ) -> go.Figure:
247
- range_radius = 70
248
  if use_biaser:
249
  risk_level = float(risk_level)
250
  else:
@@ -262,8 +262,8 @@ def prediction_plot(
262
  ),
263
  title_text="Road Scene",
264
  hovermode="closest",
265
- width=1200,
266
- height=600,
267
  updatemenus=[
268
  dict(
269
  type="buttons",
@@ -332,8 +332,7 @@ def update_figure(
332
 
333
  return fig
334
 
335
- def load_from_huggingface(model_source: str = "TRI-ML/risk_biased_model", data_source: str = "jmercat/risk_biased_dataset", config_name: str="learning_config.py", checkpoint_name: str = "last.ckpt", device: str = "cpu") -> Tuple[LitTrajectoryPredictor, Dataset]:
336
- dataset = load_dataset(data_source, split="test")
337
  config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN'))
338
  ckpt = torch.load(hf_hub_download(model_source, filename=checkpoint_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN')), map_location="cpu")
339
  cfg = Config.fromfile(config_file)
@@ -342,23 +341,31 @@ def load_from_huggingface(model_source: str = "TRI-ML/risk_biased_model", data_
342
  predictor.eval()
343
  predictor = predictor.to(device)
344
 
345
- return predictor, dataset
 
 
 
 
346
 
347
  def main(load_from=None, cfg_path=None):
348
  # Define the device to use
349
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
  # Do the same thing as above but using the gradio blocks API
352
  with gr.Blocks() as interface:
353
-
354
- predictor, dataset = load_from_huggingface(device=device)
355
-
356
- if load_from is not None:
357
- cfg = Config.fromfile(cfg_path)
358
- predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
359
- predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
360
 
361
- ui_update_fn = partial(update_figure, predictor, dataset)
362
  gr.Markdown(
363
  """
364
  # Risk-Aware Prediction
@@ -391,7 +398,7 @@ def main(load_from=None, cfg_path=None):
391
  # n_samples.change(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
392
  button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
393
 
394
- interface.launch(debug=True)
395
 
396
 
397
  if __name__ == "__main__":
 
244
  n_samples: int = 1,
245
  use_biaser: bool = True,
246
  ) -> go.Figure:
247
+ range_radius = 50
248
  if use_biaser:
249
  risk_level = float(risk_level)
250
  else:
 
262
  ),
263
  title_text="Road Scene",
264
  hovermode="closest",
265
+ width=600,
266
+ height=300,
267
  updatemenus=[
268
  dict(
269
  type="buttons",
 
332
 
333
  return fig
334
 
335
+ def load_predictor_from_hf(model_source: str = "TRI-ML/risk_biased_model", config_name: str="learning_config.py", checkpoint_name: str = "last.ckpt", device: str = "cpu") -> Tuple[LitTrajectoryPredictor, Dataset]:
 
336
  config_file = hf_hub_download(model_source, filename=config_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN'))
337
  ckpt = torch.load(hf_hub_download(model_source, filename=checkpoint_name, use_auth_token=os.getenv('SECRET_AUTH_TOKEN')), map_location="cpu")
338
  cfg = Config.fromfile(config_file)
 
341
  predictor.eval()
342
  predictor = predictor.to(device)
343
 
344
+ return predictor
345
+
346
+ def load_dataset_from_hf(data_source: str = "jmercat/risk_biased_dataset") -> Dataset:
347
+ dataset = load_dataset(data_source, split="test")
348
+ return dataset
349
 
350
  def main(load_from=None, cfg_path=None):
351
  # Define the device to use
352
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
353
+ print("Getting dataset")
354
+ dataset = load_dataset_from_hf()
355
+
356
+ if load_from is not None:
357
+ cfg = Config.fromfile(cfg_path)
358
+ predictor = get_predictor(cfg, WaymoDataloaders.unnormalize_trajectory)
359
+ predictor = load_weights(predictor, torch.load(load_from, map_location="cpu"))
360
+ else:
361
+ print("Getting model.")
362
+ predictor = load_predictor_from_hf(device=device)
363
+
364
+ ui_update_fn = partial(update_figure, predictor, dataset)
365
 
366
  # Do the same thing as above but using the gradio blocks API
367
  with gr.Blocks() as interface:
 
 
 
 
 
 
 
368
 
 
369
  gr.Markdown(
370
  """
371
  # Risk-Aware Prediction
 
398
  # n_samples.change(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
399
  button.click(ui_update_fn, inputs=[index, risk_level, n_samples, image], outputs=image)
400
 
401
+ interface.launch(debug=False)
402
 
403
 
404
  if __name__ == "__main__":