add sloth and axis
Browse files
- app.py +135 -189
- src/experiments/log/gs/ckpts/sloth_scene_1/eef_xyz.txt +3 -0
- src/experiments/log/gs/ckpts/sloth_scene_1/eef_xyz_old.txt +3 -0
- src/experiments/log/gs/ckpts/sloth_scene_1/gripper.splat +3 -0
- src/experiments/log/gs/ckpts/sloth_scene_1/gripper_old.splat +3 -0
- src/experiments/log/gs/ckpts/sloth_scene_1/object.splat +3 -0
- src/experiments/log/gs/ckpts/sloth_scene_1/table.splat +3 -0
- src/experiments/log/gs/temp/form_video.mp4 +0 -0
- src/experiments/log/gs/temp/form_video_init.mp4 +0 -0
- src/experiments/log/gs/temp/gs_pred.splat +2 -2
- src/experiments/log/sloth/train/ckpt/100000.pt +3 -0
- src/experiments/log/sloth/train/hydra.yaml +95 -0
- src/experiments/log/temp/0000.png +2 -2
- src/experiments/log/temp/0001.png +2 -2
- src/experiments/log/temp/0002.png +2 -2
- src/experiments/log/temp/0003.png +2 -2
- src/experiments/log/temp/0004.png +2 -2
- src/experiments/log/temp/0005.png +2 -2
- src/experiments/log/temp/0006.png +2 -2
- src/experiments/log/temp/0007.png +2 -2
- src/experiments/log/temp/0008.png +2 -2
- src/experiments/log/temp/0009.png +2 -2
- src/experiments/log/temp/0010.png +2 -2
- src/experiments/log/temp/0011.png +2 -2
- src/experiments/log/temp/0012.png +2 -2
- src/experiments/log/temp/0013.png +2 -2
- src/experiments/log/temp/0014.png +2 -2
- src/experiments/log/temp_init/0000.png +2 -2
app.py
CHANGED
@@ -81,11 +81,9 @@ def fps(x, enabled, n, device, random_start=False):
 
 class DynamicsVisualizer:
 
-    def __init__(self):
-
-        self.
-
-        best_models = {
+    def __init__(self, wp_device='cuda', torch_device='cuda'):
+
+        self.best_models = {
             'cloth': ['cloth', 'train', 100000, [610, 650]],
             'rope': ['rope', 'train', 100000, [651, 691]],
             'paperbag': ['paperbag', 'train', 100000, [200, 220]],
@@ -93,16 +91,21 @@ class DynamicsVisualizer:
             'box': ['box', 'train', 100000, [306, 323]],
             'bread': ['bread', 'train', 100000, [143, 163]],
         }
-
         task_name = 'rope'
+        self.init(task_name)
 
-
+    def init(self, task_name):
+        self.width = 640
+        self.height = 480
+        self.task_name = task_name
+
+        with open(root / f'log/{self.best_models[task_name][0]}/{self.best_models[task_name][1]}/hydra.yaml', 'r') as f:
            config = yaml.load(f, Loader=yaml.CLoader)
        cfg = OmegaConf.create(config)
 
-        cfg.iteration = best_models[task_name][2]
-        cfg.start_episode = best_models[task_name][3][0]
-        cfg.end_episode = best_models[task_name][3][1]
+        cfg.iteration = self.best_models[task_name][2]
+        cfg.start_episode = self.best_models[task_name][3][0]
+        cfg.end_episode = self.best_models[task_name][3][1]
         cfg.sim.num_steps = 1000
         cfg.sim.gripper_forcing = False
         cfg.sim.uniform = True
@@ -258,83 +261,6 @@ class DynamicsVisualizer:
         self.state['prev_key_pos'] = grippers[:, :3] # (1, 3)
         # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32)
         self.state['gripper_radius'] = cfg.model.gripper_radius
-
-    def load_params(self, params_path, remove_low_opa=True, remove_black=False):
-        pts, colors, scales, quats, opacities = read_splat(params_path)
-
-        if remove_low_opa:
-            low_opa_idx = opacities[:, 0] < 0.1
-            pts = pts[~low_opa_idx]
-            colors = colors[~low_opa_idx]
-            quats = quats[~low_opa_idx]
-            opacities = opacities[~low_opa_idx]
-            scales = scales[~low_opa_idx]
-
-        if remove_black:
-            low_color_idx = colors.sum(axis=-1) < 0.5
-            pts = pts[~low_color_idx]
-            colors = colors[~low_color_idx]
-            quats = quats[~low_color_idx]
-            opacities = opacities[~low_color_idx]
-            scales = scales[~low_color_idx]
-
-        self.params = {
-            'means3D': torch.from_numpy(pts).to(torch.float32).to(self.device),
-            'rgb_colors': torch.from_numpy(colors).to(torch.float32).to(self.device),
-            'log_scales': torch.log(torch.from_numpy(scales).to(torch.float32).to(self.device)),
-            'unnorm_rotations': torch.from_numpy(quats).to(torch.float32).to(self.device),
-            'logit_opacities': torch.logit(torch.from_numpy(opacities).to(torch.float32).to(self.device))
-        }
-
-        table_splat = root / 'log/gs/ckpts/table.splat'
-        sphere_splat = root / 'log/gs/ckpts/sphere.splat'
-        gripper_splat = root / 'log/gs/ckpts/gripper.splat' # gripper_new.splat
-
-        table_params = read_splat(table_splat) # numpy
-
-        ## add table and gripper
-        # add table
-        t_pts, t_colors, t_scales, t_quats, t_opacities = table_params
-        t_pts = torch.tensor(t_pts).to(torch.float32).to(self.device)
-        t_colors = torch.tensor(t_colors).to(torch.float32).to(self.device)
-        t_scales = torch.tensor(t_scales).to(torch.float32).to(self.device)
-        t_quats = torch.tensor(t_quats).to(torch.float32).to(self.device)
-        t_opacities = torch.tensor(t_opacities).to(torch.float32).to(self.device)
-
-        # add table pos
-        t_pts = t_pts + torch.tensor([0, 0, 0.02]).to(torch.float32).to(self.device)
-
-        # add gripper
-        gripper_params = read_splat(gripper_splat) # numpy
-
-        g_pts, g_colors, g_scales, g_quats, g_opacities = gripper_params
-        g_pts = torch.tensor(g_pts).to(torch.float32).to(self.device)
-        g_colors = torch.tensor(g_colors).to(torch.float32).to(self.device)
-        g_scales = torch.tensor(g_scales).to(torch.float32).to(self.device)
-        g_quats = torch.tensor(g_quats).to(torch.float32).to(self.device)
-        g_opacities = torch.tensor(g_opacities).to(torch.float32).to(self.device)
-
-        # we do not do the gripper translation now because this will center the gripper in the data frame but not the viewer frame
-
-        self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities # data frame
-        self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities # data frame
-
-        # load other info
-        n_particles = self.cfg.sim.n_particles
-        self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
-        self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)
-
-    def set_camera(self, w, h, intr, w2c=None, R=None, t=None, near=0.01, far=100.0):
-        if w2c is None:
-            assert R is not None and t is not None
-            w2c = Rt_to_w2c(R, t)
-        self.metadata = {
-            'w': w,
-            'h': h,
-            'k': intr,
-            'w2c': w2c,
-        }
-        self.config = {'near': near, 'far': far}
 
     def load_eef(self, grippers=None, eef_t=None):
         assert self.state['prev_key_pos'] is None
@@ -453,7 +379,7 @@ class DynamicsVisualizer:
         self.colliders = colliders
 
         # load ckpt
-        ckpt_path = root / 'log/
+        ckpt_path = root / f'log/{self.task_name}/train/ckpt/100000.pt'
         ckpt = torch.load(ckpt_path, map_location=self.torch_device)
 
         material: nn.Module = PGNDModel(cfg)
@@ -704,6 +630,76 @@ class DynamicsVisualizer:
         t_pts = t_pts @ R_viewer.T
         t_quats = mat2quat(R_viewer @ quat2mat(t_quats))
         t_pts += t_viewer
+
+        axes = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
+        dirs = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] # x, y, z axes
+        for ee in range(3):
+            gripper_direction = torch.tensor(dirs[ee], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
+            gripper_direction = gripper_direction / (torch.norm(gripper_direction, dim=-1, keepdim=True) + 1e-10) # normalize
+
+            R = self.preprocess_metadata['R']
+            # model frame to data frame
+            direction = gripper_direction @ R.T
+
+            n_grippers = 1
+            N = 200
+            length = 0.2
+            kk = 5
+            xyz_test = torch.zeros((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype)
+
+            if self.task_name == 'rope':
+                pos = torch.tensor([0.0, 0.0, 1.2], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) # gripper position in model frame
+            else:
+                pos = torch.tensor([1.2, 0.0, 0.7], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3)
+            gripper_now_inv_xyz = self.inverse_preprocess_gripper(pos)
+            gripper_now_inv_rot = torch.eye(3, device=self.torch_device).unsqueeze(0).repeat(n_grippers, 1, 1)
+
+            center_point = torch.tensor([0.0, 0.0, 0.10], device=self.torch_device, dtype=t_pts.dtype).reshape(1, 3) # center point in gripper frame
+            gripper_center_inv_xyz = gripper_now_inv_xyz + \
+                torch.einsum('ijk,ik->ij', gripper_now_inv_rot, center_point) # (n_grippers, 3)
+
+            for i in range(N):
+                offset = i / N * length * direction
+                xyz_test[:, i] = gripper_center_inv_xyz + offset
+
+            if direction[0, 2] < 0.9 and direction[0, 2] > -0.9: # not vertical
+                direction_up = -direction + torch.tensor([0.0, 0.0, 0.5], device=self.torch_device, dtype=t_pts.dtype)
+                direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize
+                direction_down = -direction + torch.tensor([0.0, 0.0, -0.5], device=self.torch_device, dtype=t_pts.dtype)
+                direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize
+            else:
+                direction_up = -direction + torch.tensor([0.0, 0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype)
+                direction_up = direction_up / (torch.norm(direction_up, dim=-1, keepdim=True) + 1e-10) # normalize
+                direction_down = -direction + torch.tensor([0.0, -0.5, 0.0], device=self.torch_device, dtype=t_pts.dtype)
+                direction_down = direction_down / (torch.norm(direction_down, dim=-1, keepdim=True) + 1e-10) # normalize
+
+            for i in range(N, N + N // kk):
+                offset = length * direction + (i - N) / N * length * direction_up
+                xyz_test[:, i] = gripper_center_inv_xyz + offset
+
+            for i in range(N + N // kk, N + N // kk + N // kk):
+                offset = length * direction + (i - N - N // kk) / N * length * direction_down
+                xyz_test[:, i] = gripper_center_inv_xyz + offset
+
+            color_test = torch.zeros_like(xyz_test, device=self.torch_device, dtype=t_pts.dtype)
+            color_test[:, :, 0] = axes[ee][0]
+            color_test[:, :, 1] = axes[ee][1]
+            color_test[:, :, 2] = axes[ee][2]
+            quat_test = torch.zeros((n_grippers, N + N // kk + N // kk, 4), device=self.torch_device, dtype=t_pts.dtype)
+            quat_test[:, :, 0] = 1.0 # identity quaternion
+            opa_test = torch.ones((n_grippers, N + N // kk + N // kk, 1), device=self.torch_device, dtype=t_pts.dtype)
+            scales_test = torch.ones((n_grippers, N + N // kk + N // kk, 3), device=self.torch_device, dtype=t_pts.dtype) * 0.002
+
+            t_pts = torch.cat([t_pts, xyz_test.reshape(-1, 3)], dim=0)
+            t_colors = torch.cat([t_colors, color_test.reshape(-1, 3)], dim=0)
+            t_quats = torch.cat([t_quats, quat_test.reshape(-1, 4)], dim=0)
+            t_opacities = torch.cat([t_opacities, opa_test.reshape(-1, 1)], dim=0)
+            t_scales = torch.cat([t_scales, scales_test.reshape(-1, 3)], dim=0)
+
+        t_pts = t_pts.reshape(-1, 3)
+        t_colors = t_colors.reshape(-1, 3)
+        t_quats = t_quats.reshape(-1, 4)
+        t_opacities = t_opacities.reshape(-1, 1)
 
         g_mat = quat2mat(g_quats)
         g_mat = R @ g_mat
@@ -720,7 +716,13 @@ class DynamicsVisualizer:
         # TODO: center gripper in the viewer frame
         g_pts_tip = g_pts[g_pts_tip_mask]
         g_pts_tip_mean_xy = g_pts_tip[:, :2].mean(dim=0)
-
+
+        if self.task_name == 'rope':
+            g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.23]).to(torch.float32).to(self.device)
+        elif self.task_name == 'sloth':
+            g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.32]).to(torch.float32).to(self.device)
+        else:
+            raise NotImplementedError(f"Task {self.task_name} not implemented for gripper translation.")
         g_pts = g_pts + g_pts_translation
 
         self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
@@ -882,7 +884,7 @@ class DynamicsVisualizer:
         center = (0, 0, 0.1)
         distance = 0.7
         elevation = 20
-        azimuth = 180.0
+        azimuth = 180.0 if self.task_name == 'rope' else 120.0
         target = np.array(center)
         theta = 90 + azimuth
         z = distance * math.sin(math.radians(elevation))
@@ -930,8 +932,9 @@ class DynamicsVisualizer:
 
         return rendervar_init
 
-
-
+    def reset(self, task_name, scene_name):
+        self.init(task_name)
+
         import warp as wp
         wp.init()
         gpus = [int(gpu) for gpu in self.cfg.gpus]
@@ -942,7 +945,7 @@ class DynamicsVisualizer:
        self.wp_device = wp_devices[0]
        self.torch_device = torch_devices[0]
 
-        in_dir = root / 'log/gs/ckpts/
+        in_dir = root / f'log/gs/ckpts/{scene_name}'
        batch_size = 1
        num_steps = 1
        num_particles = self.cfg.sim.n_particles
@@ -1028,7 +1031,7 @@ class DynamicsVisualizer:
         center = (0, 0, 0.1)
         distance = 0.7
         elevation = 20
-        azimuth = 180.0
+        azimuth = 180.0 if self.task_name == 'rope' else 120.0
         target = np.array(center)
         theta = 90 + azimuth
         z = distance * math.sin(math.radians(elevation))
@@ -1191,6 +1194,14 @@ class DynamicsVisualizer:
            self.preprocess_metadata, self.state, self.params, \
            self.table_params, self.gripper_params, rendervar
 
+    @spaces.GPU
+    def reset_rope(self):
+        return self.reset('rope', 'rope_scene_1')
+
+    @spaces.GPU
+    def reset_plush(self):
+        return self.reset('sloth', 'sloth_scene_1')
+
    @spaces.GPU
    def on_click_run_xplus(self, preprocess_metadata, state, params, table_params, gripper_params, rendervar):
        return self.run_command([5.0, 0, 0], preprocess_metadata, state, params, table_params, gripper_params, rendervar)
@@ -1230,27 +1241,36 @@ class DynamicsVisualizer:
 
        with gr.Row():
            gr.Markdown('### Project page: [https://kywind.github.io/pgnd](https://kywind.github.io/pgnd)')
-
+
+        with gr.Row():
+            gr.Markdown('### Instructions:')
+
        with gr.Row():
            gr.Markdown(' '.join([
-                '
-                '- Click the "Reset" button to initialize the simulation with the predicted video and Gaussian splats. Due to compute limitations of Huggingface Space, each run may take a prolonged period (up to 30 seconds).\n',
+                '- Click the "Reset-\<object\>" button to initialize the simulation with the predicted video and Gaussian splats. Due to compute limitations of Huggingface Space, each run may take a prolonged period (up to 30 seconds).\n',
                '- Use the buttons to move the gripper in the x, y, z directions. The gripper will move for a fixed length per click. The predicted video and Gaussian splats will be updated accordingly.\n',
                '- X-Y plane is the table surface, and Z is the height.\n',
                '- The predicted video from the previous step to the current step will be shown in the "Predicted video" section.\n',
                '- The Gaussian splats after the current step will be shown in the "Predicted Gaussians" section.\n',
-                '- The simulation results may deviate from the initial shape due to accumulative prediction artifacts. Click the
+                '- The simulation results may deviate from the initial shape due to accumulative prediction artifacts. Click the Reset button to reset the simulation state and reinitialize the predicted video and Gaussian splats.\n',
            ]))
-
+
+        with gr.Row():
+            gr.Markdown('### Select a scene to reset the simulation:')
+
+        with gr.Row():
+            with gr.Column(scale=2):
+                with gr.Row():
+                    with gr.Column():
+                        run_reset_plush = gr.Button("Reset - Plush")
+                    with gr.Column():
+                        run_reset_rope = gr.Button("Reset - Rope")
+
+            with gr.Column(scale=2):
+                _ = gr.Button(visible=False) # empty placeholder
 
        with gr.Row():
 
-            # with gr.Column(scale=2):
-            #     form_3dgs_orig = gr.Model3D(
-            #         label='Original Gaussian Splats',
-            #         value=None,
-            #     )
-
            with gr.Column(scale=2):
                form_video = gr.Video(
                    label='Predicted video',
@@ -1269,10 +1289,11 @@ class DynamicsVisualizer:
                )
 
        # Layout
+        with gr.Row():
+            gr.Markdown('### Control the gripper to move in the x, y, z directions:')
+
        with gr.Row():
            with gr.Column(scale=2):
-                with gr.Row():
-                    run_reset = gr.Button("Reset")
 
                with gr.Row():
                    with gr.Column():
@@ -1294,61 +1315,15 @@ class DynamicsVisualizer:
 
            with gr.Column(scale=2):
                _ = gr.Button(visible=False) # empty placeholder
-
-        # with gr.Row():
-
-        #     # with gr.Column(scale=2):
-        #     #     form_3dgs_orig = gr.Model3D(
-        #     #         label='Original Gaussian Splats',
-        #     #         value=None,
-        #     #     )
-
-        #     with gr.Column(scale=2):
-        #         form_video_2 = gr.Video(
-        #             label='Predicted video',
-        #             value=None,
-        #             format='mp4',
-        #             width=self.width,
-        #             height=self.height,
-        #         )
-
-        #     with gr.Column(scale=2):
-        #         form_3dgs_pred_2 = gr.Model3D(
-        #             label='Predicted Gaussians',
-        #             height=self.height,
-        #             value=None,
-        #             clear_color=[0, 0, 0, 0],
-        #         )
-
-        # # Layout
-        # with gr.Row():
-        #     with gr.Column(scale=2):
-        #         with gr.Row():
-        #             run_reset_2 = gr.Button("Reset")
-
-        #     with gr.Row():
-        #         with gr.Column():
-        #             run_xminus_2 = gr.Button("x-")
-        #         with gr.Column():
-        #             run_xplus_2 = gr.Button("x+")
-
-        #     with gr.Row():
-        #         with gr.Column():
-        #             run_yminus_2 = gr.Button("y-")
-        #         with gr.Column():
-        #             run_yplus_2 = gr.Button("y+")
-
-        #     with gr.Row():
-        #         with gr.Column():
-        #             run_zminus_2 = gr.Button("z-")
-        #         with gr.Column():
-        #             run_zplus_2 = gr.Button("z+")
-
-        #     with gr.Column(scale=2):
-        #         _ = gr.Button(visible=False) # empty placeholder
 
        # Set up callbacks
-
+        run_reset_rope.click(self.reset_rope,
+            inputs=[],
+            outputs=[form_video, form_3dgs_pred,
+                preprocess_metadata, state, params,
+                table_params, gripper_params, rendervar])
+
+        run_reset_plush.click(self.reset_plush,
            inputs=[],
            outputs=[form_video, form_3dgs_pred,
                preprocess_metadata, state, params,
@@ -1396,35 +1371,6 @@ class DynamicsVisualizer:
                preprocess_metadata, state, params,
                table_params, gripper_params, rendervar])
 
-        # Set up callbacks
-        # run_reset_2.click(self.reset_2,
-        #     inputs=[],
-        #     outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_xplus_2.click(self.on_click_run_xplus_2,
-        #     inputs=[],
-        #     outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_xminus_2.click(self.on_click_run_xminus_2,
-        #     inputs=[],
-        #     outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_yplus_2.click(self.on_click_run_yplus_2,
-        #     inputs=[],
-        #     outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_yminus_2.click(self.on_click_run_yminus_2,
-        #     inputs=[],
-        #     outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_zplus_2.click(self.on_click_run_zplus_2,
-        #     inputs=[],
-        #     outputs=[form_video_2, form_3dgs_pred_2])
-
-        # run_zminus_2.click(self.on_click_run_zminus_2,
-        #     inputs=[],
-        #     outputs=[form_video_2, form_3dgs_pred_2])
-
        app.launch(share=share)
 
 
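The "axis" part of this commit is the block added to the table/gripper post-processing above: each of the x/y/z axes is drawn as a line of small colored Gaussians with two short arrowhead branches, and the points are concatenated onto the table splat. Below is a minimal standalone sketch of just that point-layout step, assuming a simple origin/direction interface; the function and variable names are illustrative, not the app's.

```python
# Sketch only: generate N points along an axis direction plus two short
# arrowhead branches, mirroring the layout logic in the diff above.
import torch

def axis_marker_points(origin, direction, length=0.2, N=200, kk=5):
    direction = direction / (direction.norm() + 1e-10)
    # bend sideways unless the axis is (nearly) vertical, as in the diff
    bend = torch.tensor([0.0, 0.0, 0.5]) if direction[2].abs() < 0.9 else torch.tensor([0.0, 0.5, 0.0])
    up = -direction + bend
    up = up / (up.norm() + 1e-10)
    down = -direction - bend
    down = down / (down.norm() + 1e-10)

    t = torch.linspace(0.0, 1.0, N).unsqueeze(1)        # (N, 1)
    shaft = origin + t * length * direction             # points along the axis
    tip = origin + length * direction
    s = torch.linspace(0.0, 1.0, N // kk).unsqueeze(1) * (length / kk)
    return torch.cat([shaft, tip + s * up, tip + s * down], dim=0)

# one x-axis marker at the origin; colors, scales and opacities would be
# filled per point before concatenating onto the splat, as in the diff
pts = axis_marker_points(torch.zeros(3), torch.tensor([1.0, 0.0, 0.0]))
print(pts.shape)  # torch.Size([280, 3])
```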
src/experiments/log/gs/ckpts/sloth_scene_1/eef_xyz.txt
ADDED
@@ -0,0 +1,3 @@
+-1.1209631912410259247e-02
+-0.754003190994262695e-02
+-1.769125705957412720e-01
src/experiments/log/gs/ckpts/sloth_scene_1/eef_xyz_old.txt
ADDED
@@ -0,0 +1,3 @@
+-4.266144707798957825e-03
+-6.183005869388580322e-02
+-1.841607391834259033e-01
src/experiments/log/gs/ckpts/sloth_scene_1/gripper.splat
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:893d6c2608a022685ab7cf8d044f044afe6eda3248d2baf1c1ac6d55160f1041
+size 1151264
src/experiments/log/gs/ckpts/sloth_scene_1/gripper_old.splat
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:def7c4ecddf10a491a3717bbc271ab55b9ab35437452cf6d61666fa2ccbd7883
+size 1212288
src/experiments/log/gs/ckpts/sloth_scene_1/object.splat
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94110d55ce1ba8dd3021f70b3864b7b833f22cf94757b96640d8f9c74f0a2c1c
+size 4481248
src/experiments/log/gs/ckpts/sloth_scene_1/table.splat
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ed2090607151595c8b4b7ec44ffde9196366c2201b3bbd21e03bf924cf54e29c
+size 7051104
src/experiments/log/gs/temp/form_video.mp4
CHANGED
Binary files a/src/experiments/log/gs/temp/form_video.mp4 and b/src/experiments/log/gs/temp/form_video.mp4 differ

src/experiments/log/gs/temp/form_video_init.mp4
CHANGED
Binary files a/src/experiments/log/gs/temp/form_video_init.mp4 and b/src/experiments/log/gs/temp/form_video_init.mp4 differ
src/experiments/log/gs/temp/gs_pred.splat
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:4d1d42fd7779673768a2604429439aba5c0228c08350d4c0173d6f7cce89a293
+size 12719456
src/experiments/log/sloth/train/ckpt/100000.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ce7f86a40058c2680784ac40f633a67e00e9ce8af8a6111acc3362d71d3b052
+size 3374922
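The new sloth checkpoint is a plain PyTorch file; per the app.py diff above it is consumed with `torch.load(ckpt_path, map_location=...)` and fed into `PGNDModel`. A minimal inspection sketch follows; the path is from this commit, but the checkpoint's internal key layout is not shown here and everything beyond the load call is illustrative.

```python
# Sketch only: open the new checkpoint and list its top-level keys.
import torch

ckpt = torch.load('src/experiments/log/sloth/train/ckpt/100000.pt', map_location='cpu')
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))  # key names are not documented in this commit
else:
    print(type(ckpt))
```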
src/experiments/log/sloth/train/hydra.yaml
ADDED
@@ -0,0 +1,95 @@
+model:
+  material:
+    requires_grad: true
+    output_scale: 1.0
+    input_scale: 2.0
+    radius: 0.2
+    absolute_y: false
+    pe_num_func_res: 0
+  friction:
+    value: 0.0
+    requires_grad: false
+  ckpt: null
+  clip_bound: 1.5
+  eef_t:
+  - 0.0
+  - 0.0
+  - 0.01
+  gripper_radius: 0.04
+render:
+  width: 512
+  height: 512
+  skip_frame: 1
+  bound: 1.5
+  fps: 5
+  radius_scale: 500
+  center:
+  - 0.5
+  - 0.3
+  - 0.5
+  distance: 1.4
+  azimuth: -125
+  elevation: 30
+  reflectance:
+  - 0.92941176
+  - 0.32941176
+  - 0.23137255
+sim:
+  num_steps_train: 5
+  num_steps: 1000
+  interval: 1
+  num_grids:
+  - 50
+  - 50
+  - 50
+  - 0.02
+  dt: 0.1
+  bound: 3
+  eps: 1.0e-07
+  skip_frame: 1
+  num_grippers: 1
+  preprocess_scale: 1.0
+  preprocess_with_table: true
+  n_particles: 1000
+  gripper_forcing: true
+  gripper_points: false
+  n_history: 2
+  uniform: false
+train:
+  name: sloth/train
+  dataset_name: sloth/dataset
+  source_dataset_name: data/sloth_merged/sub_episodes_v
+  num_iterations: 100000
+  resume_iteration: 0
+  batch_size: 32
+  num_workers: 8
+  material_lr: 0.0001
+  material_wd: 0.0
+  material_grad_max_norm: 0.1
+  training_start_episode: 0
+  training_end_episode: 113
+  eval_start_episode: 113
+  eval_end_episode: 133
+  iteration_log_interval: 10
+  iteration_save_interval: 1000
+  iteration_eval_interval: 10000
+  loss_factor: 1.0
+  loss_factor_v: 0.0
+  loss_factor_x: 100.0
+  friction_lr: 0.1
+  friction_wd: 0.0
+  friction_grad_max_norm: 0.1
+  dataset_load_skip_frame: 3
+  dataset_skip_frame: 1
+  downsample: false
+  use_pv: false
+  use_gs: false
+  dataset_non_overwrite: true
+seed: 0
+cpu: 0
+num_cpus: 128
+gpus:
+- 0
+overwrite: false
+resume: true
+debug: false
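app.py reads this Hydra-dumped config exactly as in the diff above: parse it with PyYAML, wrap it in an OmegaConf object, then override a few run-time fields. A minimal sketch under those assumptions (the override values mirror the ones in the diff; the relative path assumes the repository root as the working directory):

```python
# Sketch: load the training config and apply the app's run-time overrides.
import yaml
from omegaconf import OmegaConf

with open('src/experiments/log/sloth/train/hydra.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.CLoader)  # CLoader needs libyaml; yaml.SafeLoader also works
cfg = OmegaConf.create(config)

cfg.iteration = 100000           # checkpoint iteration, as in best_models
cfg.sim.num_steps = 1000
cfg.sim.gripper_forcing = False
cfg.sim.uniform = True
print(cfg.train.name, cfg.sim.n_particles)  # -> sloth/train 1000
```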
src/experiments/log/temp/0000.png
CHANGED

src/experiments/log/temp/0001.png
CHANGED

src/experiments/log/temp/0002.png
CHANGED

src/experiments/log/temp/0003.png
CHANGED

src/experiments/log/temp/0004.png
CHANGED

src/experiments/log/temp/0005.png
CHANGED

src/experiments/log/temp/0006.png
CHANGED

src/experiments/log/temp/0007.png
CHANGED

src/experiments/log/temp/0008.png
CHANGED

src/experiments/log/temp/0009.png
CHANGED

src/experiments/log/temp/0010.png
CHANGED

src/experiments/log/temp/0011.png
CHANGED

src/experiments/log/temp/0012.png
CHANGED

src/experiments/log/temp/0013.png
CHANGED

src/experiments/log/temp/0014.png
CHANGED

src/experiments/log/temp_init/0000.png
CHANGED