kaifz committed
Commit 966acd1 · 1 Parent(s): 18bc2d6

add sloth and axis

app.py CHANGED
@@ -81,11 +81,9 @@ def fps(x, enabled, n, device, random_start=False):
 
 class DynamicsVisualizer:
 
-    def __init__(self):
-        self.width = 640
-        self.height = 480
-
-        best_models = {
+    def __init__(self, wp_device='cuda', torch_device='cuda'):
+
+        self.best_models = {
             'cloth': ['cloth', 'train', 100000, [610, 650]],
             'rope': ['rope', 'train', 100000, [651, 691]],
             'paperbag': ['paperbag', 'train', 100000, [200, 220]],
@@ -93,16 +91,21 @@ class DynamicsVisualizer:
             'box': ['box', 'train', 100000, [306, 323]],
             'bread': ['bread', 'train', 100000, [143, 163]],
         }
-
         task_name = 'rope'
+        self.init(task_name)
 
-        with open(root / f'log/{best_models[task_name][0]}/{best_models[task_name][1]}/hydra.yaml', 'r') as f:
+    def init(self, task_name):
+        self.width = 640
+        self.height = 480
+        self.task_name = task_name
+
+        with open(root / f'log/{self.best_models[task_name][0]}/{self.best_models[task_name][1]}/hydra.yaml', 'r') as f:
             config = yaml.load(f, Loader=yaml.CLoader)
             cfg = OmegaConf.create(config)
 
-        cfg.iteration = best_models[task_name][2]
-        cfg.start_episode = best_models[task_name][3][0]
-        cfg.end_episode = best_models[task_name][3][1]
+        cfg.iteration = self.best_models[task_name][2]
+        cfg.start_episode = self.best_models[task_name][3][0]
+        cfg.end_episode = self.best_models[task_name][3][1]
         cfg.sim.num_steps = 1000
         cfg.sim.gripper_forcing = False
         cfg.sim.uniform = True
@@ -258,83 +261,6 @@ class DynamicsVisualizer:
         self.state['prev_key_pos'] = grippers[:, :3] # (1, 3)
         # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32)
         self.state['gripper_radius'] = cfg.model.gripper_radius
-
-    def load_params(self, params_path, remove_low_opa=True, remove_black=False):
-        pts, colors, scales, quats, opacities = read_splat(params_path)
-
-        if remove_low_opa:
-            low_opa_idx = opacities[:, 0] < 0.1
-            pts = pts[~low_opa_idx]
-            colors = colors[~low_opa_idx]
-            quats = quats[~low_opa_idx]
-            opacities = opacities[~low_opa_idx]
-            scales = scales[~low_opa_idx]
-
-        if remove_black:
-            low_color_idx = colors.sum(axis=-1) < 0.5
-            pts = pts[~low_color_idx]
-            colors = colors[~low_color_idx]
-            quats = quats[~low_color_idx]
-            opacities = opacities[~low_color_idx]
-            scales = scales[~low_color_idx]
-
-        self.params = {
-            'means3D': torch.from_numpy(pts).to(torch.float32).to(self.device),
-            'rgb_colors': torch.from_numpy(colors).to(torch.float32).to(self.device),
-            'log_scales': torch.log(torch.from_numpy(scales).to(torch.float32).to(self.device)),
-            'unnorm_rotations': torch.from_numpy(quats).to(torch.float32).to(self.device),
-            'logit_opacities': torch.logit(torch.from_numpy(opacities).to(torch.float32).to(self.device))
-        }
-
-        table_splat = root / 'log/gs/ckpts/table.splat'
-        sphere_splat = root / 'log/gs/ckpts/sphere.splat'
-        gripper_splat = root / 'log/gs/ckpts/gripper.splat' # gripper_new.splat
-
-        table_params = read_splat(table_splat) # numpy
-
-        ## add table and gripper
-        # add table
-        t_pts, t_colors, t_scales, t_quats, t_opacities = table_params
-        t_pts = torch.tensor(t_pts).to(torch.float32).to(self.device)
-        t_colors = torch.tensor(t_colors).to(torch.float32).to(self.device)
-        t_scales = torch.tensor(t_scales).to(torch.float32).to(self.device)
-        t_quats = torch.tensor(t_quats).to(torch.float32).to(self.device)
-        t_opacities = torch.tensor(t_opacities).to(torch.float32).to(self.device)
-
-        # add table pos
-        t_pts = t_pts + torch.tensor([0, 0, 0.02]).to(torch.float32).to(self.device)
-
-        # add gripper
-        gripper_params = read_splat(gripper_splat) # numpy
-
-        g_pts, g_colors, g_scales, g_quats, g_opacities = gripper_params
-        g_pts = torch.tensor(g_pts).to(torch.float32).to(self.device)
-        g_colors = torch.tensor(g_colors).to(torch.float32).to(self.device)
-        g_scales = torch.tensor(g_scales).to(torch.float32).to(self.device)
-        g_quats = torch.tensor(g_quats).to(torch.float32).to(self.device)
-        g_opacities = torch.tensor(g_opacities).to(torch.float32).to(self.device)
-
-        # we do not do the gripper translation now because this will center the gripper in the data frame but not the viewer frame
-
-        self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities # data frame
-        self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities # data frame
-
-        # load other info
-        n_particles = self.cfg.sim.n_particles
-        self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
-        self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)
-
-    def set_camera(self, w, h, intr, w2c=None, R=None, t=None, near=0.01, far=100.0):
-        if w2c is None:
-            assert R is not None and t is not None
-            w2c = Rt_to_w2c(R, t)
-        self.metadata = {
-            'w': w,
-            'h': h,
-            'k': intr,
-            'w2c': w2c,
-        }
-        self.config = {'near': near, 'far': far}
 
     def load_eef(self, grippers=None, eef_t=None):
         assert self.state['prev_key_pos'] is None
@@ -453,7 +379,7 @@ class DynamicsVisualizer:
         self.colliders = colliders
 
         # load ckpt
-        ckpt_path = root / 'log/rope/train/ckpt/100000.pt'
+        ckpt_path = root / f'log/{self.task_name}/train/ckpt/100000.pt'
        ckpt = torch.load(ckpt_path, map_location=self.torch_device)
 
         material: nn.Module = PGNDModel(cfg)
@@ -704,6 +630,76 @@ class DynamicsVisualizer:
         t_pts = t_pts @ R_viewer.T
         t_quats = mat2quat(R_viewer @ quat2mat(t_quats))
         t_pts += t_viewer
+
+        axes = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
+        dirs = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] # x, y, z axes
+        for ee in range(3):
+            gripper_direction = torch.tensor(dirs[ee], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
+            gripper_direction = gripper_direction / (torch.norm(gripper_direction, dim=-1, keepdim=True) + 1e-10) # normalize
+
+            R = self.preprocess_metadata['R']
+            # model frame to data frame
+            direction = gripper_direction @ R.T
+
+            n_grippers = 1
+            N = 200
+            length = 0.2
+            kk = 5
+            xyz_test = torch.zeros((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype)
+
+            if self.task_name == 'rope':
+                pos = torch.tensor([0.0, 0.0, 1.2], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) # gripper position in model frame
+            else:
+                pos = torch.tensor([1.2, 0.0, 0.7], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
+            gripper_now_inv_xyz = self.inverse_preprocess_gripper(pos)
+            gripper_now_inv_rot = torch.eye(3, device=self.torch_device).unsqueeze(0).repeat(n_grippers, 1, 1)
+
+            center_point = torch.tensor([0.0, 0.0, 0.10], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) # center point in gripper frame
+            gripper_center_inv_xyz = gripper_now_inv_xyz + \
+                torch.einsum('ijk,ik->ij', gripper_now_inv_rot, center_point) # (n_grippers, 3)
+
+            for i in range(N):
+                offset = i / N * length * direction
+                xyz_test[:, i] = gripper_center_inv_xyz + offset
+
+            if direction[0, 2] < 0.9 and direction[0, 2] > -0.9: # not vertical
+                direction_up = -direction + torch.tensor([0.0, 0.0, 0.5], device=self.torch_device, dtype=t_pts.dtype)
+                direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize
+                direction_down = -direction + torch.tensor([0.0, 0.0, -0.5], device=self.torch_device, dtype=t_pts.dtype)
+                direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize
+            else:
+                direction_up = -direction + torch.tensor([0.0, 0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype)
+                direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize
+                direction_down = -direction + torch.tensor([0.0, -0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype)
+                direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize
+
+            for i in range(N, N + N // kk):
+                offset = length * direction + (i - N) / N * length * direction_up
+                xyz_test[:, i] = gripper_center_inv_xyz + offset
+
+            for i in range(N + N // kk, N + N // kk + N // kk):
+                offset = length * direction + (i - N - N // kk) / N * length * direction_down
+                xyz_test[:, i] = gripper_center_inv_xyz + offset
+
+            color_test = torch.zeros_like(xyz_test, device=self.torch_device, dtype=t_pts.dtype)
+            color_test[:, :, 0] = axes[ee][0]
+            color_test[:, :, 1] = axes[ee][1]
+            color_test[:, :, 2] = axes[ee][2]
+            quat_test = torch.zeros((n_grippers, N + N // kk + N // kk, 4), device=self.torch_device, dtype=t_pts.dtype)
+            quat_test[:, :, 0] = 1.0 # identity quaternion
+            opa_test = torch.ones((n_grippers, N + N // kk + N // kk, 1), device=self.torch_device, dtype=t_pts.dtype)
+            scales_test = torch.ones((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype) * 0.002
+
+            t_pts = torch.cat([t_pts, xyz_test.reshape(-1, 3)], dim=0)
+            t_colors = torch.cat([t_colors, color_test.reshape(-1, 3)], dim=0)
+            t_quats = torch.cat([t_quats, quat_test.reshape(-1, 4)], dim=0)
+            t_opacities = torch.cat([t_opacities, opa_test.reshape(-1, 1)], dim=0)
+            t_scales = torch.cat([t_scales, scales_test.reshape(-1, 3)], dim=0)
+
+        t_pts = t_pts.reshape(-1, 3)
+        t_colors = t_colors.reshape(-1, 3)
+        t_quats = t_quats.reshape(-1, 4)
+        t_opacities = t_opacities.reshape(-1, 1)
 
         g_mat = quat2mat(g_quats)
         g_mat = R @ g_mat
@@ -720,7 +716,13 @@ class DynamicsVisualizer:
         # TODO: center gripper in the viewer frame
         g_pts_tip = g_pts[g_pts_tip_mask]
         g_pts_tip_mean_xy = g_pts_tip[:, :2].mean(dim=0)
-        g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.23]).to(torch.float32).to(self.device)
+
+        if self.task_name == 'rope':
+            g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.23]).to(torch.float32).to(self.device)
+        elif self.task_name == 'sloth':
+            g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.32]).to(torch.float32).to(self.device)
+        else:
+            raise NotImplementedError(f"Task {self.task_name} not implemented for gripper translation.")
         g_pts = g_pts + g_pts_translation
 
         self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
@@ -882,7 +884,7 @@ class DynamicsVisualizer:
         center = (0, 0, 0.1)
         distance = 0.7
         elevation = 20
-        azimuth = 180.0
+        azimuth = 180.0 if self.task_name == 'rope' else 120.0
         target = np.array(center)
         theta = 90 + azimuth
         z = distance * math.sin(math.radians(elevation))
@@ -930,8 +932,9 @@ class DynamicsVisualizer:
 
         return rendervar_init
 
-    @spaces.GPU
-    def reset(self):
+    def reset(self, task_name, scene_name):
+        self.init(task_name)
+
         import warp as wp
         wp.init()
         gpus = [int(gpu) for gpu in self.cfg.gpus]
@@ -942,7 +945,7 @@ class DynamicsVisualizer:
         self.wp_device = wp_devices[0]
         self.torch_device = torch_devices[0]
 
-        in_dir = root / 'log/gs/ckpts/rope_scene_1'
+        in_dir = root / f'log/gs/ckpts/{scene_name}'
         batch_size = 1
         num_steps = 1
         num_particles = self.cfg.sim.n_particles
@@ -1028,7 +1031,7 @@ class DynamicsVisualizer:
         center = (0, 0, 0.1)
         distance = 0.7
         elevation = 20
-        azimuth = 180.0
+        azimuth = 180.0 if self.task_name == 'rope' else 120.0
         target = np.array(center)
         theta = 90 + azimuth
         z = distance * math.sin(math.radians(elevation))
@@ -1191,6 +1194,14 @@ class DynamicsVisualizer:
             self.preprocess_metadata, self.state, self.params, \
             self.table_params, self.gripper_params, rendervar
 
+    @spaces.GPU
+    def reset_rope(self):
+        return self.reset('rope', 'rope_scene_1')
+
+    @spaces.GPU
+    def reset_plush(self):
+        return self.reset('sloth', 'sloth_scene_1')
+
     @spaces.GPU
     def on_click_run_xplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar):
         return self.run_command([5.0, 0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar)
@@ -1230,27 +1241,36 @@ class DynamicsVisualizer:
 
         with gr.Row():
             gr.Markdown('### Project page: [https://kywind.github.io/pgnd](https://kywind.github.io/pgnd)')
-
+
+        with gr.Row():
+            gr.Markdown('### Instructions:')
+
         with gr.Row():
             gr.Markdown(' '.join([
-                'Instructions:\n',
-                '- Click the "Reset" button to initialize the simulation with the predicted video and Gaussian splats. Due to compute limitations of Huggingface Space, each run may take a prolonged period (up to 30 seconds).\n',
+                '- Click the "Reset-\<object\>" button to initialize the simulation with the predicted video and Gaussian splats. Due to compute limitations of Huggingface Space, each run may take a prolonged period (up to 30 seconds).\n',
                 '- Use the buttons to move the gripper in the x, y, z directions. The gripper will move for a fixed length per click. The predicted video and Gaussian splats will be updated accordingly.\n',
                 '- X-Y plane is the table surface, and Z is the height.\n',
                 '- The predicted video from the previous step to the current step will be shown in the "Predicted video" section.\n',
                 '- The Gaussian splats after the current step will be shown in the "Predicted Gaussians" section.\n',
-                '- The simulation results may deviate from the initial shape due to accumulative prediction artifacts. Click the "Reset" button to reset the simulation state and reinitialize the predicted video and Gaussian splats.\n',
+                '- The simulation results may deviate from the initial shape due to accumulative prediction artifacts. Click the Reset button to reset the simulation state and reinitialize the predicted video and Gaussian splats.\n',
             ]))
-
+
+        with gr.Row():
+            gr.Markdown('### Select a scene to reset the simulation:')
+
+        with gr.Row():
+            with gr.Column(scale=2):
+                with gr.Row():
+                    with gr.Column():
+                        run_reset_plush = gr.Button("Reset - Plush")
+                    with gr.Column():
+                        run_reset_rope = gr.Button("Reset - Rope")
+
+            with gr.Column(scale=2):
+                _ = gr.Button(visible=False) # empty placeholder
 
         with gr.Row():
 
-            # with gr.Column(scale=2):
-            #     form_3dgs_orig = gr.Model3D(
-            #         label='Original Gaussian Splats',
-            #         value=None,
-            #     )
-
             with gr.Column(scale=2):
                 form_video = gr.Video(
                     label='Predicted video',
@@ -1269,10 +1289,11 @@ class DynamicsVisualizer:
                 )
 
         # Layout
+        with gr.Row():
+            gr.Markdown('### Control the gripper to move in the x, y, z directions:')
+
         with gr.Row():
             with gr.Column(scale=2):
-                with gr.Row():
-                    run_reset = gr.Button("Reset")
 
                 with gr.Row():
                     with gr.Column():
@@ -1294,61 +1315,15 @@ class DynamicsVisualizer:
 
             with gr.Column(scale=2):
                 _ = gr.Button(visible=False) # empty placeholder
-
-        # with gr.Row():
-
-        #     # with gr.Column(scale=2):
-        #     #     form_3dgs_orig = gr.Model3D(
-        #     #         label='Original Gaussian Splats',
-        #     #         value=None,
-        #     #     )
-
-        #     with gr.Column(scale=2):
-        #         form_video_2 = gr.Video(
-        #             label='Predicted video',
-        #             value=None,
-        #             format='mp4',
-        #             width=self.width,
-        #             height=self.height,
-        #         )
-
-        #     with gr.Column(scale=2):
-        #         form_3dgs_pred_2 = gr.Model3D(
-        #             label='Predicted Gaussians',
-        #             height=self.height,
-        #             value=None,
-        #             clear_color=[0, 0, 0, 0],
-        #         )
-
-        # # Layout
-        # with gr.Row():
-        #     with gr.Column(scale=2):
-        #         with gr.Row():
-        #             run_reset_2 = gr.Button("Reset")
-
-        #         with gr.Row():
-        #             with gr.Column():
-        #                 run_xminus_2 = gr.Button("x-")
-        #             with gr.Column():
-        #                 run_xplus_2 = gr.Button("x+")
-
-        #         with gr.Row():
-        #             with gr.Column():
-        #                 run_yminus_2 = gr.Button("y-")
-        #             with gr.Column():
-        #                 run_yplus_2 = gr.Button("y+")
-
-        #         with gr.Row():
-        #             with gr.Column():
-        #                 run_zminus_2 = gr.Button("z-")
-        #             with gr.Column():
-        #                 run_zplus_2 = gr.Button("z+")
-
-        #     with gr.Column(scale=2):
-        #         _ = gr.Button(visible=False) # empty placeholder
 
         # Set up callbacks
-        run_reset.click(self.reset,
+        run_reset_rope.click(self.reset_rope,
+                inputs=[],
+                outputs=[form_video, form_3dgs_pred,
+                        preprocess_metadata, state, params,
+                        table_params, gripper_params, rendervar])
+
+        run_reset_plush.click(self.reset_plush,
                 inputs=[],
                 outputs=[form_video, form_3dgs_pred,
                         preprocess_metadata, state, params,
@@ -1396,35 +1371,6 @@ class DynamicsVisualizer:
                         preprocess_metadata, state, params,
                         table_params, gripper_params, rendervar])
 
-        # Set up callbacks
-        # run_reset_2.click(self.reset_2,
-        #         inputs=[],
-        #         outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_xplus_2.click(self.on_click_run_xplus_2,
-        #         inputs=[],
-        #         outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_xminus_2.click(self.on_click_run_xminus_2,
-        #         inputs=[],
-        #         outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_yplus_2.click(self.on_click_run_yplus_2,
-        #         inputs=[],
-        #         outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_yminus_2.click(self.on_click_run_yminus_2,
-        #         inputs=[],
-        #         outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_zplus_2.click(self.on_click_run_zplus_2,
-        #         inputs=[],
-        #         outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_zminus_2.click(self.on_click_run_zminus_2,
-        #         inputs=[],
-        #         outputs=[form_video_2, form_3dgs_pred_2])
-
         app.launch(share=share)
 
 
 
1430
 
 
81
 
82
  class DynamicsVisualizer:
83
 
84
+ def __init__(self, wp_device='cuda', torch_device='cuda'):
85
+
86
+ self.best_models = {
 
 
87
  'cloth': ['cloth', 'train', 100000, [610, 650]],
88
  'rope': ['rope', 'train', 100000, [651, 691]],
89
  'paperbag': ['paperbag', 'train', 100000, [200, 220]],
 
91
  'box': ['box', 'train', 100000, [306, 323]],
92
  'bread': ['bread', 'train', 100000, [143, 163]],
93
  }
 
94
  task_name = 'rope'
95
+ self.init(task_name)
96
 
97
+ def init(self, task_name):
98
+ self.width = 640
99
+ self.height = 480
100
+ self.task_name = task_name
101
+
102
+ with open(root / f'log/{self.best_models[task_name][0]}/{self.best_models[task_name][1]}/hydra.yaml', 'r') as f:
103
  config = yaml.load(f, Loader=yaml.CLoader)
104
  cfg = OmegaConf.create(config)
105
 
106
+ cfg.iteration = self.best_models[task_name][2]
107
+ cfg.start_episode = self.best_models[task_name][3][0]
108
+ cfg.end_episode = self.best_models[task_name][3][1]
109
  cfg.sim.num_steps = 1000
110
  cfg.sim.gripper_forcing = False
111
  cfg.sim.uniform = True
 
261
  self.state['prev_key_pos'] = grippers[:, :3] # (1, 3)
262
  # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32)
263
  self.state['gripper_radius'] = cfg.model.gripper_radius
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
  def load_eef(self, grippers=None, eef_t=None):
266
  assert self.state['prev_key_pos'] is None
 
379
  self.colliders = colliders
380
 
381
  # load ckpt
382
+ ckpt_path = root / f'log/{self.task_name}/train/ckpt/100000.pt'
383
  ckpt = torch.load(ckpt_path, map_location=self.torch_device)
384
 
385
  material: nn.Module = PGNDModel(cfg)
 
630
  t_pts = t_pts @ R_viewer.T
631
  t_quats = mat2quat(R_viewer @ quat2mat(t_quats))
632
  t_pts += t_viewer
633
+
634
+ axes = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
635
+ dirs = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] # x, y, z axes
636
+ for ee in range(3):
637
+ gripper_direction = torch.tensor(dirs[ee], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
638
+ gripper_direction = gripper_direction / (torch.norm(gripper_direction, dim=-1, keepdim=True) + 1e-10) # normalize
639
+
640
+ R = self.preprocess_metadata['R']
641
+ # model frame to data frame
642
+ direction = gripper_direction @ R.T
643
+
644
+ n_grippers = 1
645
+ N = 200
646
+ length = 0.2
647
+ kk = 5
648
+ xyz_test = torch.zeros((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype)
649
+
650
+ if self.task_name == 'rope':
651
+ pos = torch.tensor([0.0, 0.0, 1.2], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) # gripper position in model frame
652
+ else:
653
+ pos = torch.tensor([1.2, 0.0, 0.7], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
654
+ gripper_now_inv_xyz = self.inverse_preprocess_gripper(pos)
655
+ gripper_now_inv_rot = torch.eye(3, device=self.torch_device).unsqueeze(0).repeat(n_grippers, 1, 1)
656
+
657
+ center_point = torch.tensor([0.0, 0.0, 0.10], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) # center point in gripper frame
658
+ gripper_center_inv_xyz = gripper_now_inv_xyz + \
659
+ torch.einsum('ijk,ik->ij', gripper_now_inv_rot, center_point) # (n_grippers, 3)
660
+
661
+ for i in range(N):
662
+ offset = i / N * length * direction
663
+ xyz_test[:, i] = gripper_center_inv_xyz + offset
664
+
665
+ if direction[0, 2] < 0.9 and direction[0, 2] > -0.9: # not vertical
666
+ direction_up = -direction + torch.tensor([0.0, 0.0, 0.5], device=self.torch_device, dtype=t_pts.dtype)
667
+ direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize
668
+ direction_down = -direction + torch.tensor([0.0, 0.0, -0.5], device=self.torch_device, dtype=t_pts.dtype)
669
+ direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize
670
+ else:
671
+ direction_up = -direction + torch.tensor([0.0, 0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype)
672
+ direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize
673
+ direction_down = -direction + torch.tensor([0.0, -0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype)
674
+ direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize
675
+
676
+ for i in range(N, N + N // kk):
677
+ offset = length * direction + (i - N) / N * length * direction_up
678
+ xyz_test[:, i] = gripper_center_inv_xyz + offset
679
+
680
+ for i in range(N + N // kk, N + N // kk + N // kk):
681
+ offset = length * direction + (i - N - N // kk) / N * length * direction_down
682
+ xyz_test[:, i] = gripper_center_inv_xyz + offset
683
+
684
+ color_test = torch.zeros_like(xyz_test, device=self.torch_device, dtype=t_pts.dtype)
685
+ color_test[:, :, 0] = axes[ee][0]
686
+ color_test[:, :, 1] = axes[ee][1]
687
+ color_test[:, :, 2] = axes[ee][2]
688
+ quat_test = torch.zeros((n_grippers, N + N // kk + N // kk, 4), device=self.torch_device, dtype=t_pts.dtype)
689
+ quat_test[:, :, 0] = 1.0 # identity quaternion
690
+ opa_test = torch.ones((n_grippers, N + N // kk + N // kk, 1), device=self.torch_device, dtype=t_pts.dtype)
691
+ scales_test = torch.ones((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype) * 0.002
692
+
693
+ t_pts = torch.cat([t_pts, xyz_test.reshape(-1, 3)], dim=0)
694
+ t_colors = torch.cat([t_colors, color_test.reshape(-1, 3)], dim=0)
695
+ t_quats = torch.cat([t_quats, quat_test.reshape(-1, 4)], dim=0)
696
+ t_opacities = torch.cat([t_opacities, opa_test.reshape(-1, 1)], dim=0)
697
+ t_scales = torch.cat([t_scales, scales_test.reshape(-1, 3)], dim=0)
698
+
699
+ t_pts = t_pts.reshape(-1, 3)
700
+ t_colors = t_colors.reshape(-1, 3)
701
+ t_quats = t_quats.reshape(-1, 4)
702
+ t_opacities = t_opacities.reshape(-1, 1)
703
 
704
  g_mat = quat2mat(g_quats)
705
  g_mat = R @ g_mat
 
716
  # TODO: center gripper in the viewer frame
717
  g_pts_tip = g_pts[g_pts_tip_mask]
718
  g_pts_tip_mean_xy = g_pts_tip[:, :2].mean(dim=0)
719
+
720
+ if self.task_name == 'rope':
721
+ g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.23]).to(torch.float32).to(self.device)
722
+ elif self.task_name == 'sloth':
723
+ g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.32]).to(torch.float32).to(self.device)
724
+ else:
725
+ raise NotImplementedError(f"Task {self.task_name} not implemented for gripper translation.")
726
  g_pts = g_pts + g_pts_translation
727
 
728
  self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
 
884
  center = (0, 0, 0.1)
885
  distance = 0.7
886
  elevation = 20
887
+ azimuth = 180.0 if self.task_name == 'rope' else 120.0
888
  target = np.array(center)
889
  theta = 90 + azimuth
890
  z = distance * math.sin(math.radians(elevation))
 
932
 
933
  return rendervar_init
934
 
935
+ def reset(self, task_name, scene_name):
936
+ self.init(task_name)
937
+
938
  import warp as wp
939
  wp.init()
940
  gpus = [int(gpu) for gpu in self.cfg.gpus]
 
945
  self.wp_device = wp_devices[0]
946
  self.torch_device = torch_devices[0]
947
 
948
+ in_dir = root / f'log/gs/ckpts/{scene_name}'
949
  batch_size = 1
950
  num_steps = 1
951
  num_particles = self.cfg.sim.n_particles
 
1031
  center = (0, 0, 0.1)
1032
  distance = 0.7
1033
  elevation = 20
1034
+ azimuth = 180.0 if self.task_name == 'rope' else 120.0
1035
  target = np.array(center)
1036
  theta = 90 + azimuth
1037
  z = distance * math.sin(math.radians(elevation))
 
1194
  self.preprocess_metadata, self.state, self.params, \
1195
  self.table_params, self.gripper_params, rendervar
1196
 
1197
+ @spaces.GPU
1198
+ def reset_rope(self):
1199
+ return self.reset('rope', 'rope_scene_1')
1200
+
1201
+ @spaces.GPU
1202
+ def reset_plush(self):
1203
+ return self.reset('sloth', 'sloth_scene_1')
1204
+
1205
  @spaces.GPU
1206
  def on_click_run_xplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar):
1207
  return self.run_command([5.0, 0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar)
 
1241
 
1242
  with gr.Row():
1243
  gr.Markdown('### Project page: [https://kywind.github.io/pgnd](https://kywind.github.io/pgnd)')
1244
+
1245
+ with gr.Row():
1246
+ gr.Markdown('### Instructions:')
1247
+
1248
  with gr.Row():
1249
  gr.Markdown(' '.join([
1250
+ '- Click the "Reset-\<object\>" button to initialize the simulation with the predicted video and Gaussian splats. Due to compute limitations of Huggingface Space, each run may take a prolonged period (up to 30 seconds).\n',
 
1251
  '- Use the buttons to move the gripper in the x, y, z directions. The gripper will move for a fixed length per click. The predicted video and Gaussian splats will be updated accordingly.\n',
1252
  '- X-Y plane is the table surface, and Z is the height.\n',
1253
  '- The predicted video from the previous step to the current step will be shown in the "Predicted video" section.\n',
1254
  '- The Gaussian splats after the current step will be shown in the "Predicted Gaussians" section.\n',
1255
+ '- The simulation results may deviate from the initial shape due to accumulative prediction artifacts. Click the Reset button to reset the simulation state and reinitialize the predicted video and Gaussian splats.\n',
1256
  ]))
1257
+
1258
+ with gr.Row():
1259
+ gr.Markdown('### Select a scene to reset the simulation:')
1260
+
1261
+ with gr.Row():
1262
+ with gr.Column(scale=2):
1263
+ with gr.Row():
1264
+ with gr.Column():
1265
+ run_reset_plush = gr.Button("Reset - Plush")
1266
+ with gr.Column():
1267
+ run_reset_rope = gr.Button("Reset - Rope")
1268
+
1269
+ with gr.Column(scale=2):
1270
+ _ = gr.Button(visible=False) # empty placeholder
1271
 
1272
  with gr.Row():
1273
 
 
 
 
 
 
 
1274
  with gr.Column(scale=2):
1275
  form_video = gr.Video(
1276
  label='Predicted video',
 
1289
  )
1290
 
1291
  # Layout
1292
+ with gr.Row():
1293
+ gr.Markdown('### Control the gripper to move in the x, y, z directions:')
1294
+
1295
  with gr.Row():
1296
  with gr.Column(scale=2):
 
 
1297
 
1298
  with gr.Row():
1299
  with gr.Column():
 
1315
 
1316
  with gr.Column(scale=2):
1317
  _ = gr.Button(visible=False) # empty placeholder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1318
 
1319
  # Set up callbacks
1320
+ run_reset_rope.click(self.reset_rope,
1321
+ inputs=[],
1322
+ outputs=[form_video, form_3dgs_pred,
1323
+ preprocess_metadata, state, params,
1324
+ table_params, gripper_params, rendervar])
1325
+
1326
+ run_reset_plush.click(self.reset_plush,
1327
  inputs=[],
1328
  outputs=[form_video, form_3dgs_pred,
1329
  preprocess_metadata, state, params,
 
1371
  preprocess_metadata, state, params,
1372
  table_params, gripper_params, rendervar])
1373
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1374
  app.launch(share=share)
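The two `azimuth` hunks feed the same small orbit-camera setup; switching from 180° to 120° for non-rope tasks only rotates the default viewpoint around the table. The diff shows `theta = 90 + azimuth` and the `z` component, so the sketch below completes the x/y components with the conventional horizontal-circle parameterization; that completion is an assumption, not code from this commit:

```python
import math
import numpy as np

def orbit_camera_position(center=(0, 0, 0.1), distance=0.7, elevation=20, azimuth=180.0):
    """Map (center, distance, elevation, azimuth) to a camera position.

    The target/theta/z lines mirror the diff; the x/y completion assumes a
    standard circle in the horizontal plane of radius distance*cos(elevation).
    """
    target = np.array(center)
    theta = 90 + azimuth
    z = distance * math.sin(math.radians(elevation))
    r_xy = distance * math.cos(math.radians(elevation))  # assumed
    x = r_xy * math.cos(math.radians(theta))             # assumed
    y = r_xy * math.sin(math.radians(theta))             # assumed
    return target + np.array([x, y, z])

print(orbit_camera_position(azimuth=180.0))  # rope view
print(orbit_camera_position(azimuth=120.0))  # sloth/plush view, per the hunks above
```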
src/experiments/log/gs/ckpts/sloth_scene_1/eef_xyz.txt ADDED
@@ -0,0 +1,3 @@
+-1.1209631912410259247e-02
+-0.754003190994262695e-02
+-1.769125705957412720e-01
src/experiments/log/gs/ckpts/sloth_scene_1/eef_xyz_old.txt ADDED
@@ -0,0 +1,3 @@
+-4.266144707798957825e-03
+-6.183005869388580322e-02
+-1.841607391834259033e-01
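Each of the two `eef_xyz` text files above stores a single 3-vector, one component per line (from the name, presumably an end-effector offset for the sloth scene; the exact meaning is not stated in this commit). Loading one is plain whitespace-separated float parsing, e.g.:

```python
import numpy as np

# one float per line, as shown in the diff above
eef_xyz = np.loadtxt('src/experiments/log/gs/ckpts/sloth_scene_1/eef_xyz.txt')
assert eef_xyz.shape == (3,)
print(eef_xyz)  # approx. [-0.0112, -0.0075, -0.1769]
```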
src/experiments/log/gs/ckpts/sloth_scene_1/gripper.splat ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:893d6c2608a022685ab7cf8d044f044afe6eda3248d2baf1c1ac6d55160f1041
+size 1151264
src/experiments/log/gs/ckpts/sloth_scene_1/gripper_old.splat ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:def7c4ecddf10a491a3717bbc271ab55b9ab35437452cf6d61666fa2ccbd7883
+size 1212288
src/experiments/log/gs/ckpts/sloth_scene_1/object.splat ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94110d55ce1ba8dd3021f70b3864b7b833f22cf94757b96640d8f9c74f0a2c1c
+size 4481248
src/experiments/log/gs/ckpts/sloth_scene_1/table.splat ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed2090607151595c8b4b7ec44ffde9196366c2201b3bbd21e03bf924cf54e29c
+size 7051104
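The four `.splat` files above are the Gaussian-splat assets for the new sloth scene (object, table, and two gripper variants), committed as Git LFS pointers. In `app.py` they go through the repo's own `read_splat` helper, which, per the removed `load_params` above, yields `(pts, colors, scales, quats, opacities)`. As a rough stand-in, here is a reader for the common 32-byte-per-Gaussian `.splat` layout used by web splat viewers (float32 xyz, float32 scales, RGBA uint8, uint8-packed quaternion); this byte layout is an assumption, not confirmed by the repo:

```python
import numpy as np

def read_splat_guess(path):
    """Parse a .splat file, assuming the common 32-byte-per-Gaussian layout.

    The field order matches what the app's read_splat appears to return; the
    byte layout itself is a guess, not the repo's implementation.
    """
    raw = np.fromfile(path, dtype=np.uint8).reshape(-1, 32)
    floats = raw[:, :24].copy().view(np.float32)                 # (N, 6): xyz + scales
    pts = floats[:, :3]
    scales = floats[:, 3:6]
    rgba = raw[:, 24:28].astype(np.float32) / 255.0
    colors = rgba[:, :3]
    opacities = rgba[:, 3:4]                                     # alpha channel as opacity
    quats = (raw[:, 28:32].astype(np.float32) - 128.0) / 128.0   # packed rotation
    return pts, colors, scales, quats, opacities
```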
src/experiments/log/gs/temp/form_video.mp4 CHANGED
Binary files a/src/experiments/log/gs/temp/form_video.mp4 and b/src/experiments/log/gs/temp/form_video.mp4 differ
 
src/experiments/log/gs/temp/form_video_init.mp4 CHANGED
Binary files a/src/experiments/log/gs/temp/form_video_init.mp4 and b/src/experiments/log/gs/temp/form_video_init.mp4 differ
 
src/experiments/log/gs/temp/gs_pred.splat CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:753a59d8cb6fb82e29d233db6865e6c6a87fabc5b469619e4310d3fbee619f6c
-size 7684352
+oid sha256:4d1d42fd7779673768a2604429439aba5c0228c08350d4c0173d6f7cce89a293
+size 12719456
src/experiments/log/sloth/train/ckpt/100000.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ce7f86a40058c2680784ac40f633a67e00e9ce8af8a6111acc3362d71d3b052
+size 3374922
src/experiments/log/sloth/train/hydra.yaml ADDED
@@ -0,0 +1,95 @@
+model:
+  material:
+    requires_grad: true
+    output_scale: 1.0
+    input_scale: 2.0
+    radius: 0.2
+    absolute_y: false
+    pe_num_func_res: 0
+  friction:
+    value: 0.0
+    requires_grad: false
+  ckpt: null
+  clip_bound: 1.5
+  eef_t:
+  - 0.0
+  - 0.0
+  - 0.01
+  gripper_radius: 0.04
+render:
+  width: 512
+  height: 512
+  skip_frame: 1
+  bound: 1.5
+  fps: 5
+  radius_scale: 500
+  center:
+  - 0.5
+  - 0.3
+  - 0.5
+  distance: 1.4
+  azimuth: -125
+  elevation: 30
+  reflectance:
+  - 0.92941176
+  - 0.32941176
+  - 0.23137255
+sim:
+  num_steps_train: 5
+  num_steps: 1000
+  interval: 1
+  num_grids:
+  - 50
+  - 50
+  - 50
+  - 0.02
+  dt: 0.1
+  bound: 3
+  eps: 1.0e-07
+  skip_frame: 1
+  num_grippers: 1
+  preprocess_scale: 1.0
+  preprocess_with_table: true
+  n_particles: 1000
+  gripper_forcing: true
+  gripper_points: false
+  n_history: 2
+  uniform: false
+train:
+  name: sloth/train
+  dataset_name: sloth/dataset
+  source_dataset_name: data/sloth_merged/sub_episodes_v
+  num_iterations: 100000
+  resume_iteration: 0
+  batch_size: 32
+  num_workers: 8
+  material_lr: 0.0001
+  material_wd: 0.0
+  material_grad_max_norm: 0.1
+  training_start_episode: 0
+  training_end_episode: 113
+  eval_start_episode: 113
+  eval_end_episode: 133
+  iteration_log_interval: 10
+  iteration_save_interval: 1000
+  iteration_eval_interval: 10000
+  loss_factor: 1.0
+  loss_factor_v: 0.0
+  loss_factor_x: 100.0
+  friction_lr: 0.1
+  friction_wd: 0.0
+  friction_grad_max_norm: 0.1
+  dataset_load_skip_frame: 3
+  dataset_skip_frame: 1
+  downsample: false
+  use_pv: false
+  use_gs: false
+  dataset_non_overwrite: true
+seed: 0
+cpu: 0
+num_cpus: 128
+gpus:
+- 0
+overwrite: false
+resume: true
+debug: false
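This `hydra.yaml` is what `DynamicsVisualizer.init('sloth')` now reads (see the `app.py` hunks above): the dumped Hydra config is parsed with PyYAML, wrapped in an OmegaConf object for dotted access, and a few fields are overridden at runtime. A short sketch of that flow using only keys visible in this file; the episode-range overrides are left out because the sloth `best_models` entry is not shown in this diff:

```python
import yaml
from omegaconf import OmegaConf

# parse the dumped Hydra config and wrap it, as app.py does
with open('src/experiments/log/sloth/train/hydra.yaml', 'r') as f:
    cfg = OmegaConf.create(yaml.load(f, Loader=yaml.CLoader))

# runtime overrides mirroring the diff; 100000 matches the new ckpt/100000.pt
cfg.iteration = 100000
cfg.sim.num_steps = 1000
cfg.sim.gripper_forcing = False
cfg.sim.uniform = True

print(cfg.model.gripper_radius, cfg.sim.n_particles, cfg.gpus)
```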
src/experiments/log/temp/0000.png CHANGED
Binary files a/src/experiments/log/temp/0000.png and b/src/experiments/log/temp/0000.png differ

src/experiments/log/temp/0001.png CHANGED
Binary files a/src/experiments/log/temp/0001.png and b/src/experiments/log/temp/0001.png differ

src/experiments/log/temp/0002.png CHANGED
Binary files a/src/experiments/log/temp/0002.png and b/src/experiments/log/temp/0002.png differ

src/experiments/log/temp/0003.png CHANGED
Binary files a/src/experiments/log/temp/0003.png and b/src/experiments/log/temp/0003.png differ

src/experiments/log/temp/0004.png CHANGED
Binary files a/src/experiments/log/temp/0004.png and b/src/experiments/log/temp/0004.png differ

src/experiments/log/temp/0005.png CHANGED
Binary files a/src/experiments/log/temp/0005.png and b/src/experiments/log/temp/0005.png differ

src/experiments/log/temp/0006.png CHANGED
Binary files a/src/experiments/log/temp/0006.png and b/src/experiments/log/temp/0006.png differ

src/experiments/log/temp/0007.png CHANGED
Binary files a/src/experiments/log/temp/0007.png and b/src/experiments/log/temp/0007.png differ

src/experiments/log/temp/0008.png CHANGED
Binary files a/src/experiments/log/temp/0008.png and b/src/experiments/log/temp/0008.png differ

src/experiments/log/temp/0009.png CHANGED
Binary files a/src/experiments/log/temp/0009.png and b/src/experiments/log/temp/0009.png differ

src/experiments/log/temp/0010.png CHANGED
Binary files a/src/experiments/log/temp/0010.png and b/src/experiments/log/temp/0010.png differ

src/experiments/log/temp/0011.png CHANGED
Binary files a/src/experiments/log/temp/0011.png and b/src/experiments/log/temp/0011.png differ

src/experiments/log/temp/0012.png CHANGED
Binary files a/src/experiments/log/temp/0012.png and b/src/experiments/log/temp/0012.png differ

src/experiments/log/temp/0013.png CHANGED
Binary files a/src/experiments/log/temp/0013.png and b/src/experiments/log/temp/0013.png differ

src/experiments/log/temp/0014.png CHANGED
Binary files a/src/experiments/log/temp/0014.png and b/src/experiments/log/temp/0014.png differ

src/experiments/log/temp_init/0000.png CHANGED
Binary files a/src/experiments/log/temp_init/0000.png and b/src/experiments/log/temp_init/0000.png differ