kaifz commited on
Commit
5de6891
·
1 Parent(s): a1f54d5
Files changed (1) hide show
  1. app.py +35 -67
app.py CHANGED
@@ -113,6 +113,7 @@ def fps_new(x, enabled, n, device, random_start=False):
113
 
114
  class DynamicsVisualizer:
115
 
 
116
  def __init__(self):
117
  wp.init()
118
 
@@ -172,6 +173,7 @@ class DynamicsVisualizer:
172
 
173
  self.clear()
174
 
 
175
  def clear(self, clear_params=True):
176
  self.metadata = {}
177
  self.config = {}
@@ -204,6 +206,7 @@ class DynamicsVisualizer:
204
  self.material = None
205
  self.friction = None
206
 
 
207
  def load_scaniverse(self, data_path):
208
 
209
  ### load splat params
@@ -300,7 +303,7 @@ class DynamicsVisualizer:
300
  # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32)
301
  self.state['gripper_radius'] = cfg.model.gripper_radius
302
 
303
-
304
  def load_params(self, params_path, remove_low_opa=True, remove_black=False):
305
  pts, colors, scales, quats, opacities = read_splat(params_path)
306
 
@@ -366,7 +369,7 @@ class DynamicsVisualizer:
366
  self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
367
  self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)
368
 
369
-
370
  def set_camera(self, w, h, intr, w2c=None, R=None, t=None, near=0.01, far=100.0):
371
  if w2c is None:
372
  assert R is not None and t is not None
@@ -379,6 +382,7 @@ class DynamicsVisualizer:
379
  }
380
  self.config = {'near': near, 'far': far}
381
 
 
382
  def load_eef(self, grippers=None, eef_t=None):
383
  assert self.state['prev_key_pos'] is None
384
 
@@ -401,6 +405,7 @@ class DynamicsVisualizer:
401
  # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32) + 0.001
402
  self.state['gripper_radius'] = self.cfg.model.gripper_radius
403
 
 
404
  def load_preprocess_metadata(self, p_x_orig):
405
  cfg = self.cfg
406
  dx = cfg.sim.num_grids[-1]
@@ -437,6 +442,7 @@ class DynamicsVisualizer:
437
  'global_translation': global_translation,
438
  }
439
 
 
440
  @torch.no_grad
441
  def render(self, render_data, cam_id, bg=[0.7, 0.7, 0.7]):
442
  render_data = {k: v.to(self.device) for k, v in render_data.items()}
@@ -446,6 +452,7 @@ class DynamicsVisualizer:
446
  im, _, depth, = GaussianRasterizer(raster_settings=cam)(**render_data)
447
  return im, depth
448
 
 
449
  def knn_relations(self, bones):
450
  k = self.k_rel
451
  knn = NearestNeighbors(n_neighbors=k+1, algorithm='kd_tree').fit(bones.detach().cpu().numpy())
@@ -453,6 +460,7 @@ class DynamicsVisualizer:
453
  indices = indices[:, 1:] # exclude self
454
  return indices
455
 
 
456
  def knn_weights_brute(self, bones, pts):
457
  k = self.k_wgt
458
  dist = torch.norm(pts[:, None] - bones, dim=-1) # (n_pts, n_bones)
@@ -465,6 +473,7 @@ class DynamicsVisualizer:
465
  weights_all[torch.arange(pts.shape[0])[:, None], indices] = weights
466
  return weights_all
467
 
 
468
  def update_camera(self, k, w2c, w=None, h=None, near=0.01, far=100.0):
469
  self.metadata['k'] = k
470
  self.metadata['w2c'] = w2c
@@ -475,6 +484,7 @@ class DynamicsVisualizer:
475
  self.config['near'] = near
476
  self.config['far'] = far
477
 
 
478
  def init_model(self, batch_size, num_steps, num_particles, ckpt_path=None):
479
  self.cfg.sim.num_steps = num_steps
480
  cfg = self.cfg
@@ -510,11 +520,13 @@ class DynamicsVisualizer:
510
  self.material = material
511
  self.friction = friction
512
 
 
513
  def reload_model(self, num_steps): # only change num_steps
514
  self.cfg.sim.num_steps = num_steps
515
  sim = CacheDiffSimWithFrictionBatch(self.cfg, num_steps, 1, self.wp_device, requires_grad=True)
516
  self.sim = sim
517
 
 
518
  @torch.no_grad
519
  def step(self):
520
  cfg = self.cfg
@@ -604,6 +616,7 @@ class DynamicsVisualizer:
604
  self.state['sub_pos'] = None
605
  # self.state['sub_pos_timestamps'] = None
606
 
 
607
  def preprocess_x(self, p_x): # viewer frame to model frame (not data frame)
608
  R = self.preprocess_metadata['R']
609
  R_viewer = self.preprocess_metadata['R_viewer']
@@ -621,6 +634,7 @@ class DynamicsVisualizer:
621
 
622
  return p_x
623
 
 
624
  def preprocess_gripper(self, grippers): # viewer frame to model frame (not data frame)
625
  R = self.preprocess_metadata['R']
626
  R_viewer = self.preprocess_metadata['R_viewer']
@@ -633,6 +647,7 @@ class DynamicsVisualizer:
633
 
634
  return grippers
635
 
 
636
  def inverse_preprocess_x(self, p_x): # model frame (not data frame) to viewer frame
637
  R = self.preprocess_metadata['R']
638
  R_viewer = self.preprocess_metadata['R_viewer']
@@ -645,6 +660,7 @@ class DynamicsVisualizer:
645
 
646
  return p_x
647
 
 
648
  def inverse_preprocess_gripper(self, grippers): # model frame (not data frame) to viewer frame
649
  R = self.preprocess_metadata['R']
650
  R_viewer = self.preprocess_metadata['R_viewer']
@@ -657,55 +673,20 @@ class DynamicsVisualizer:
657
 
658
  return grippers
659
 
660
- def preprocess_gs(self, params):
661
- if isinstance(params, dict):
662
- xyz = params['means3D']
663
- rgb = params['rgb_colors']
664
- quat = torch.nn.functional.normalize(params['unnorm_rotations'])
665
- opa = torch.sigmoid(params['logit_opacities'])
666
- scales = torch.exp(params['log_scales'])
667
- else:
668
- assert isinstance(params, tuple)
669
- xyz, rgb, quat, opa, scales = params
670
-
671
- quat = torch.nn.functional.normalize(quat, dim=-1)
672
-
673
- # transform
674
- R = self.preprocess_metadata['R']
675
- R_viewer = self.preprocess_metadata['R_viewer']
676
- scale = self.preprocess_metadata['scale']
677
- global_translation = self.preprocess_metadata['global_translation']
678
-
679
- mat = quat2mat(quat)
680
- mat = R @ mat
681
- xyz = xyz @ R.T
682
- xyz = xyz * scale
683
- xyz += global_translation
684
- quat = mat2quat(mat)
685
- scales = scales * scale
686
-
687
- # viewer-specific transform (flip y and z)
688
- # model frame to viewer frame
689
- xyz = xyz @ R_viewer.T
690
- quat = mat2quat(R_viewer @ quat2mat(quat))
691
-
692
- t_viewer = -xyz.mean(dim=0)
693
- t_viewer[2] = 0
694
- xyz += t_viewer
695
- print('Overwriting t_viewer to be the planar mean of the object')
696
- self.preprocess_metadata['t_viewer'] = t_viewer
697
-
698
- if isinstance(params, dict):
699
- params['means3D'] = xyz
700
- params['rgb_colors'] = rgb
701
- params['unnorm_rotations'] = quat
702
- params['logit_opacities'] = opa
703
- params['log_scales'] = torch.log(scales)
704
- else:
705
- params = xyz, rgb, quat, opa, scales
706
 
 
 
 
 
 
 
 
707
  return params
708
 
 
709
  def preprocess_gs(self, params):
710
  if isinstance(params, dict):
711
  xyz = params['means3D']
@@ -755,6 +736,7 @@ class DynamicsVisualizer:
755
 
756
  return params
757
 
 
758
  def preprocess_bg_gs(self):
759
  t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params
760
  g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params
@@ -802,6 +784,7 @@ class DynamicsVisualizer:
802
  self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
803
  self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities
804
 
 
805
  def update_rendervar(self, rendervar):
806
  p_x = self.state['x']
807
  p_x_viewer = self.inverse_preprocess_x(p_x)
@@ -878,6 +861,7 @@ class DynamicsVisualizer:
878
 
879
  return rendervar, rendervar_full
880
 
 
881
  def reset_state(self, params, visualize_image=False, init=False):
882
  xyz_0 = params['means3D']
883
  rgb_0 = params['rgb_colors']
@@ -945,12 +929,7 @@ class DynamicsVisualizer:
945
 
946
  return rendervar_init
947
 
948
-
949
-
950
-
951
-
952
-
953
-
954
  def reset(self):
955
  params = self.preprocess_gs(self.params)
956
  if self.with_bg:
@@ -992,6 +971,7 @@ class DynamicsVisualizer:
992
 
993
  return form_video, form_3dgs_pred
994
 
 
995
  def run_command(self, unit_command):
996
 
997
  os.system('rm -rf ' + str(root / 'log/temp/*'))
@@ -1095,21 +1075,9 @@ class DynamicsVisualizer:
1095
 
1096
  def on_click_run_zminus(self):
1097
  return self.run_command([0, 0, -5.0])
1098
-
1099
- def rotate(self, params, rot_mat):
1100
- scale = np.linalg.norm(rot_mat, axis=1, keepdims=True)
1101
-
1102
- params = {
1103
- 'means3D': pts,
1104
- 'rgb_colors': params['rgb_colors'],
1105
- 'log_scales': params['log_scales'],
1106
- 'unnorm_rotations': quats,
1107
- 'logit_opacities': params['logit_opacities'],
1108
- }
1109
- return params
1110
 
 
1111
  def launch(self, share=False):
1112
-
1113
  in_dir = root / 'log/gs/ckpts/rope_scene_1'
1114
  batch_size = 1
1115
  num_steps = 1
 
113
 
114
  class DynamicsVisualizer:
115
 
116
+ @spaces.GPU
117
  def __init__(self):
118
  wp.init()
119
 
 
173
 
174
  self.clear()
175
 
176
+ @spaces.GPU
177
  def clear(self, clear_params=True):
178
  self.metadata = {}
179
  self.config = {}
 
206
  self.material = None
207
  self.friction = None
208
 
209
+ @spaces.GPU
210
  def load_scaniverse(self, data_path):
211
 
212
  ### load splat params
 
303
  # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32)
304
  self.state['gripper_radius'] = cfg.model.gripper_radius
305
 
306
+ @spaces.GPU
307
  def load_params(self, params_path, remove_low_opa=True, remove_black=False):
308
  pts, colors, scales, quats, opacities = read_splat(params_path)
309
 
 
369
  self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
370
  self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)
371
 
372
+ @spaces.GPU
373
  def set_camera(self, w, h, intr, w2c=None, R=None, t=None, near=0.01, far=100.0):
374
  if w2c is None:
375
  assert R is not None and t is not None
 
382
  }
383
  self.config = {'near': near, 'far': far}
384
 
385
+ @spaces.GPU
386
  def load_eef(self, grippers=None, eef_t=None):
387
  assert self.state['prev_key_pos'] is None
388
 
 
405
  # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32) + 0.001
406
  self.state['gripper_radius'] = self.cfg.model.gripper_radius
407
 
408
+ @spaces.GPU
409
  def load_preprocess_metadata(self, p_x_orig):
410
  cfg = self.cfg
411
  dx = cfg.sim.num_grids[-1]
 
442
  'global_translation': global_translation,
443
  }
444
 
445
+ @spaces.GPU
446
  @torch.no_grad
447
  def render(self, render_data, cam_id, bg=[0.7, 0.7, 0.7]):
448
  render_data = {k: v.to(self.device) for k, v in render_data.items()}
 
452
  im, _, depth, = GaussianRasterizer(raster_settings=cam)(**render_data)
453
  return im, depth
454
 
455
+ @spaces.GPU
456
  def knn_relations(self, bones):
457
  k = self.k_rel
458
  knn = NearestNeighbors(n_neighbors=k+1, algorithm='kd_tree').fit(bones.detach().cpu().numpy())
 
460
  indices = indices[:, 1:] # exclude self
461
  return indices
462
 
463
+ @spaces.GPU
464
  def knn_weights_brute(self, bones, pts):
465
  k = self.k_wgt
466
  dist = torch.norm(pts[:, None] - bones, dim=-1) # (n_pts, n_bones)
 
473
  weights_all[torch.arange(pts.shape[0])[:, None], indices] = weights
474
  return weights_all
475
 
476
+ @spaces.GPU
477
  def update_camera(self, k, w2c, w=None, h=None, near=0.01, far=100.0):
478
  self.metadata['k'] = k
479
  self.metadata['w2c'] = w2c
 
484
  self.config['near'] = near
485
  self.config['far'] = far
486
 
487
+ @spaces.GPU
488
  def init_model(self, batch_size, num_steps, num_particles, ckpt_path=None):
489
  self.cfg.sim.num_steps = num_steps
490
  cfg = self.cfg
 
520
  self.material = material
521
  self.friction = friction
522
 
523
+ @spaces.GPU
524
  def reload_model(self, num_steps): # only change num_steps
525
  self.cfg.sim.num_steps = num_steps
526
  sim = CacheDiffSimWithFrictionBatch(self.cfg, num_steps, 1, self.wp_device, requires_grad=True)
527
  self.sim = sim
528
 
529
+ @spaces.GPU
530
  @torch.no_grad
531
  def step(self):
532
  cfg = self.cfg
 
616
  self.state['sub_pos'] = None
617
  # self.state['sub_pos_timestamps'] = None
618
 
619
+ @spaces.GPU
620
  def preprocess_x(self, p_x): # viewer frame to model frame (not data frame)
621
  R = self.preprocess_metadata['R']
622
  R_viewer = self.preprocess_metadata['R_viewer']
 
634
 
635
  return p_x
636
 
637
+ @spaces.GPU
638
  def preprocess_gripper(self, grippers): # viewer frame to model frame (not data frame)
639
  R = self.preprocess_metadata['R']
640
  R_viewer = self.preprocess_metadata['R_viewer']
 
647
 
648
  return grippers
649
 
650
+ @spaces.GPU
651
  def inverse_preprocess_x(self, p_x): # model frame (not data frame) to viewer frame
652
  R = self.preprocess_metadata['R']
653
  R_viewer = self.preprocess_metadata['R_viewer']
 
660
 
661
  return p_x
662
 
663
+ @spaces.GPU
664
  def inverse_preprocess_gripper(self, grippers): # model frame (not data frame) to viewer frame
665
  R = self.preprocess_metadata['R']
666
  R_viewer = self.preprocess_metadata['R_viewer']
 
673
 
674
  return grippers
675
 
676
+ @spaces.GPU
677
+ def rotate(self, params, rot_mat):
678
+ scale = np.linalg.norm(rot_mat, axis=1, keepdims=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
679
 
680
+ params = {
681
+ 'means3D': pts,
682
+ 'rgb_colors': params['rgb_colors'],
683
+ 'log_scales': params['log_scales'],
684
+ 'unnorm_rotations': quats,
685
+ 'logit_opacities': params['logit_opacities'],
686
+ }
687
  return params
688
 
689
+ @spaces.GPU
690
  def preprocess_gs(self, params):
691
  if isinstance(params, dict):
692
  xyz = params['means3D']
 
736
 
737
  return params
738
 
739
+ @spaces.GPU
740
  def preprocess_bg_gs(self):
741
  t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params
742
  g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params
 
784
  self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
785
  self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities
786
 
787
+ @spaces.GPU
788
  def update_rendervar(self, rendervar):
789
  p_x = self.state['x']
790
  p_x_viewer = self.inverse_preprocess_x(p_x)
 
861
 
862
  return rendervar, rendervar_full
863
 
864
+ @spaces.GPU
865
  def reset_state(self, params, visualize_image=False, init=False):
866
  xyz_0 = params['means3D']
867
  rgb_0 = params['rgb_colors']
 
929
 
930
  return rendervar_init
931
 
932
+ @spaces.GPU
 
 
 
 
 
933
  def reset(self):
934
  params = self.preprocess_gs(self.params)
935
  if self.with_bg:
 
971
 
972
  return form_video, form_3dgs_pred
973
 
974
+ @spaces.GPU
975
  def run_command(self, unit_command):
976
 
977
  os.system('rm -rf ' + str(root / 'log/temp/*'))
 
1075
 
1076
  def on_click_run_zminus(self):
1077
  return self.run_command([0, 0, -5.0])
 
 
 
 
 
 
 
 
 
 
 
 
1078
 
1079
+ @spaces.GPU
1080
  def launch(self, share=False):
 
1081
  in_dir = root / 'log/gs/ckpts/rope_scene_1'
1082
  batch_size = 1
1083
  num_steps = 1