update
app.py CHANGED
@@ -113,6 +113,7 @@ def fps_new(x, enabled, n, device, random_start=False):
 
 class DynamicsVisualizer:
 
+    @spaces.GPU
     def __init__(self):
         wp.init()
 
@@ -172,6 +173,7 @@ class DynamicsVisualizer:
 
         self.clear()
 
+    @spaces.GPU
     def clear(self, clear_params=True):
         self.metadata = {}
         self.config = {}
@@ -204,6 +206,7 @@ class DynamicsVisualizer:
         self.material = None
         self.friction = None
 
+    @spaces.GPU
     def load_scaniverse(self, data_path):
 
         ### load splat params
@@ -300,7 +303,7 @@ class DynamicsVisualizer:
         # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32)
         self.state['gripper_radius'] = cfg.model.gripper_radius
 
-
+    @spaces.GPU
     def load_params(self, params_path, remove_low_opa=True, remove_black=False):
         pts, colors, scales, quats, opacities = read_splat(params_path)
 
@@ -366,7 +369,7 @@ class DynamicsVisualizer:
         self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
         self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)
 
-
+    @spaces.GPU
     def set_camera(self, w, h, intr, w2c=None, R=None, t=None, near=0.01, far=100.0):
         if w2c is None:
             assert R is not None and t is not None
@@ -379,6 +382,7 @@ class DynamicsVisualizer:
         }
         self.config = {'near': near, 'far': far}
 
+    @spaces.GPU
     def load_eef(self, grippers=None, eef_t=None):
         assert self.state['prev_key_pos'] is None
 
@@ -401,6 +405,7 @@ class DynamicsVisualizer:
         # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32) + 0.001
         self.state['gripper_radius'] = self.cfg.model.gripper_radius
 
+    @spaces.GPU
     def load_preprocess_metadata(self, p_x_orig):
         cfg = self.cfg
         dx = cfg.sim.num_grids[-1]
@@ -437,6 +442,7 @@ class DynamicsVisualizer:
             'global_translation': global_translation,
         }
 
+    @spaces.GPU
     @torch.no_grad
     def render(self, render_data, cam_id, bg=[0.7, 0.7, 0.7]):
         render_data = {k: v.to(self.device) for k, v in render_data.items()}
@@ -446,6 +452,7 @@ class DynamicsVisualizer:
         im, _, depth, = GaussianRasterizer(raster_settings=cam)(**render_data)
         return im, depth
 
+    @spaces.GPU
     def knn_relations(self, bones):
         k = self.k_rel
         knn = NearestNeighbors(n_neighbors=k+1, algorithm='kd_tree').fit(bones.detach().cpu().numpy())
@@ -453,6 +460,7 @@ class DynamicsVisualizer:
         indices = indices[:, 1:] # exclude self
         return indices
 
+    @spaces.GPU
     def knn_weights_brute(self, bones, pts):
         k = self.k_wgt
         dist = torch.norm(pts[:, None] - bones, dim=-1) # (n_pts, n_bones)
@@ -465,6 +473,7 @@ class DynamicsVisualizer:
         weights_all[torch.arange(pts.shape[0])[:, None], indices] = weights
         return weights_all
 
+    @spaces.GPU
     def update_camera(self, k, w2c, w=None, h=None, near=0.01, far=100.0):
         self.metadata['k'] = k
         self.metadata['w2c'] = w2c
@@ -475,6 +484,7 @@ class DynamicsVisualizer:
         self.config['near'] = near
         self.config['far'] = far
 
+    @spaces.GPU
     def init_model(self, batch_size, num_steps, num_particles, ckpt_path=None):
         self.cfg.sim.num_steps = num_steps
         cfg = self.cfg
@@ -510,11 +520,13 @@ class DynamicsVisualizer:
         self.material = material
         self.friction = friction
 
+    @spaces.GPU
     def reload_model(self, num_steps): # only change num_steps
         self.cfg.sim.num_steps = num_steps
         sim = CacheDiffSimWithFrictionBatch(self.cfg, num_steps, 1, self.wp_device, requires_grad=True)
         self.sim = sim
 
+    @spaces.GPU
     @torch.no_grad
     def step(self):
         cfg = self.cfg
@@ -604,6 +616,7 @@ class DynamicsVisualizer:
         self.state['sub_pos'] = None
         # self.state['sub_pos_timestamps'] = None
 
+    @spaces.GPU
     def preprocess_x(self, p_x): # viewer frame to model frame (not data frame)
         R = self.preprocess_metadata['R']
         R_viewer = self.preprocess_metadata['R_viewer']
@@ -621,6 +634,7 @@ class DynamicsVisualizer:
 
         return p_x
 
+    @spaces.GPU
     def preprocess_gripper(self, grippers): # viewer frame to model frame (not data frame)
         R = self.preprocess_metadata['R']
         R_viewer = self.preprocess_metadata['R_viewer']
@@ -633,6 +647,7 @@ class DynamicsVisualizer:
 
         return grippers
 
+    @spaces.GPU
     def inverse_preprocess_x(self, p_x): # model frame (not data frame) to viewer frame
         R = self.preprocess_metadata['R']
         R_viewer = self.preprocess_metadata['R_viewer']
@@ -645,6 +660,7 @@ class DynamicsVisualizer:
 
         return p_x
 
+    @spaces.GPU
     def inverse_preprocess_gripper(self, grippers): # model frame (not data frame) to viewer frame
         R = self.preprocess_metadata['R']
         R_viewer = self.preprocess_metadata['R_viewer']
@@ -657,55 +673,20 @@ class DynamicsVisualizer:
 
         return grippers
 
-
-
-
-            rgb = params['rgb_colors']
-            quat = torch.nn.functional.normalize(params['unnorm_rotations'])
-            opa = torch.sigmoid(params['logit_opacities'])
-            scales = torch.exp(params['log_scales'])
-        else:
-            assert isinstance(params, tuple)
-            xyz, rgb, quat, opa, scales = params
-
-        quat = torch.nn.functional.normalize(quat, dim=-1)
-
-        # transform
-        R = self.preprocess_metadata['R']
-        R_viewer = self.preprocess_metadata['R_viewer']
-        scale = self.preprocess_metadata['scale']
-        global_translation = self.preprocess_metadata['global_translation']
-
-        mat = quat2mat(quat)
-        mat = R @ mat
-        xyz = xyz @ R.T
-        xyz = xyz * scale
-        xyz += global_translation
-        quat = mat2quat(mat)
-        scales = scales * scale
-
-        # viewer-specific transform (flip y and z)
-        # model frame to viewer frame
-        xyz = xyz @ R_viewer.T
-        quat = mat2quat(R_viewer @ quat2mat(quat))
-
-        t_viewer = -xyz.mean(dim=0)
-        t_viewer[2] = 0
-        xyz += t_viewer
-        print('Overwriting t_viewer to be the planar mean of the object')
-        self.preprocess_metadata['t_viewer'] = t_viewer
-
-        if isinstance(params, dict):
-            params['means3D'] = xyz
-            params['rgb_colors'] = rgb
-            params['unnorm_rotations'] = quat
-            params['logit_opacities'] = opa
-            params['log_scales'] = torch.log(scales)
-        else:
-            params = xyz, rgb, quat, opa, scales
+    @spaces.GPU
+    def rotate(self, params, rot_mat):
+        scale = np.linalg.norm(rot_mat, axis=1, keepdims=True)
 
+        params = {
+            'means3D': pts,
+            'rgb_colors': params['rgb_colors'],
+            'log_scales': params['log_scales'],
+            'unnorm_rotations': quats,
+            'logit_opacities': params['logit_opacities'],
+        }
         return params
 
+    @spaces.GPU
     def preprocess_gs(self, params):
         if isinstance(params, dict):
             xyz = params['means3D']
@@ -755,6 +736,7 @@ class DynamicsVisualizer:
 
         return params
 
+    @spaces.GPU
     def preprocess_bg_gs(self):
         t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params
         g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params
@@ -802,6 +784,7 @@ class DynamicsVisualizer:
         self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
         self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities
 
+    @spaces.GPU
     def update_rendervar(self, rendervar):
         p_x = self.state['x']
         p_x_viewer = self.inverse_preprocess_x(p_x)
@@ -878,6 +861,7 @@ class DynamicsVisualizer:
 
         return rendervar, rendervar_full
 
+    @spaces.GPU
     def reset_state(self, params, visualize_image=False, init=False):
         xyz_0 = params['means3D']
         rgb_0 = params['rgb_colors']
@@ -945,12 +929,7 @@ class DynamicsVisualizer:
 
         return rendervar_init
 
-
-
-
-
-
-
+    @spaces.GPU
     def reset(self):
         params = self.preprocess_gs(self.params)
         if self.with_bg:
@@ -992,6 +971,7 @@ class DynamicsVisualizer:
 
         return form_video, form_3dgs_pred
 
+    @spaces.GPU
     def run_command(self, unit_command):
 
         os.system('rm -rf ' + str(root / 'log/temp/*'))
@@ -1095,21 +1075,9 @@ class DynamicsVisualizer:
 
     def on_click_run_zminus(self):
         return self.run_command([0, 0, -5.0])
-
-    def rotate(self, params, rot_mat):
-        scale = np.linalg.norm(rot_mat, axis=1, keepdims=True)
-
-        params = {
-            'means3D': pts,
-            'rgb_colors': params['rgb_colors'],
-            'log_scales': params['log_scales'],
-            'unnorm_rotations': quats,
-            'logit_opacities': params['logit_opacities'],
-        }
-        return params
 
+    @spaces.GPU
     def launch(self, share=False):
-
         in_dir = root / 'log/gs/ckpts/rope_scene_1'
         batch_size = 1
         num_steps = 1
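For context on the change itself: `spaces.GPU` is the decorator provided by Hugging Face's `spaces` package on ZeroGPU Spaces. A GPU is attached to the process only while a decorated function runs, which is why this commit marks the `DynamicsVisualizer` methods throughout. A minimal sketch of the pattern being applied (the function and tensor names below are illustrative, not taken from app.py):

    import spaces
    import torch

    @spaces.GPU                 # a GPU is allocated only for the duration of this call
    def double(x: torch.Tensor) -> torch.Tensor:
        return (x.cuda() * 2).cpu()

    @spaces.GPU(duration=120)   # longer jobs may request an explicit duration in seconds
    @torch.no_grad()            # stacks with other decorators, as render() and step() do above
    def simulate(x: torch.Tensor) -> torch.Tensor:
        return x.cuda().sin().cpu()

On ZeroGPU hardware, initializing CUDA at import time fails because no GPU is attached yet; keeping all CUDA work inside decorated functions, as this commit does, avoids that.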
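One technique visible in the diff is also worth unpacking: `knn_weights_brute` computes dense point-to-bone distances and scatters normalized weights for each point's k nearest bones. A self-contained sketch under that reading (the distance and scatter lines mirror the diff; the inverse-distance normalization in between is an assumption, since those lines fall outside the shown context):

    import torch

    def knn_weights_brute(bones: torch.Tensor, pts: torch.Tensor, k: int = 5) -> torch.Tensor:
        # pairwise distances, shape (n_pts, n_bones), as in the diff
        dist = torch.norm(pts[:, None] - bones, dim=-1)
        # indices of the k nearest bones per point
        nn_dist, indices = dist.topk(k, dim=-1, largest=False)
        # assumed weighting: inverse distance, normalized to sum to 1 per point
        weights = 1.0 / (nn_dist + 1e-8)
        weights = weights / weights.sum(dim=-1, keepdim=True)
        # scatter into a dense (n_pts, n_bones) weight matrix, as in the diff
        weights_all = torch.zeros(pts.shape[0], bones.shape[0])
        weights_all[torch.arange(pts.shape[0])[:, None], indices] = weights
        return weights_all

Each row of the result sums to 1 over that point's k nearest bones, the usual form for skinning-style interpolation weights.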
|