kaifz commited on
Commit
dfdc1cc
·
1 Parent(s): f7aecd1
Files changed (1) hide show
  1. app.py +99 -12
app.py CHANGED
@@ -976,10 +976,12 @@ class DynamicsVisualizer:
976
  self.params[k][k2] = v2.detach().cpu() if isinstance(v2, torch.Tensor) else v2
977
  else:
978
  self.params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
979
- for k, v in self.table_params.items():
980
- self.table_params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
981
- for k, v in self.gripper_params.items():
982
- self.gripper_params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
 
 
983
 
984
  form_video = gr.Video(
985
  label='Predicted video',
@@ -1019,10 +1021,12 @@ class DynamicsVisualizer:
1019
  self.params[k][k2] = v2.to(self.torch_device) if isinstance(v2, torch.Tensor) else v2
1020
  else:
1021
  self.params[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v
1022
- for k, v in self.table_params.items():
1023
- self.table_params[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v
1024
- for k, v in self.gripper_params.items():
1025
- self.gripper_params[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v
 
 
1026
 
1027
  # im_list = []
1028
  for i in range(15):
@@ -1086,10 +1090,12 @@ class DynamicsVisualizer:
1086
  self.params[k][k2] = v2.detach().cpu() if isinstance(v2, torch.Tensor) else v2
1087
  else:
1088
  self.params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
1089
- for k, v in self.table_params.items():
1090
- self.table_params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
1091
- for k, v in self.gripper_params.items():
1092
- self.gripper_params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
 
 
1093
 
1094
  make_video(root / 'log/temp', root / f'log/gs/temp/form_video.mp4', '%04d.png', 5)
1095
 
@@ -1206,6 +1212,58 @@ class DynamicsVisualizer:
1206
 
1207
  with gr.Column(scale=2):
1208
  _ = gr.Button(visible=False) # empty placeholder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1209
 
1210
  # Set up callbacks
1211
  run_reset.click(self.reset,
@@ -1236,6 +1294,35 @@ class DynamicsVisualizer:
1236
  inputs=[],
1237
  outputs=[form_video, form_3dgs_pred])
1238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1239
  app.launch(share=share)
1240
 
1241
 
 
976
  self.params[k][k2] = v2.detach().cpu() if isinstance(v2, torch.Tensor) else v2
977
  else:
978
  self.params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
979
+ self.table_params = tuple(
980
+ v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.table_params
981
+ )
982
+ self.gripper_params = tuple(
983
+ v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.gripper_params
984
+ )
985
 
986
  form_video = gr.Video(
987
  label='Predicted video',
 
1021
  self.params[k][k2] = v2.to(self.torch_device) if isinstance(v2, torch.Tensor) else v2
1022
  else:
1023
  self.params[k] = v.to(self.torch_device) if isinstance(v, torch.Tensor) else v
1024
+ self.table_params = tuple(
1025
+ v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for v in self.table_params
1026
+ )
1027
+ self.gripper_params = tuple(
1028
+ v.to(self.torch_device) if isinstance(v, torch.Tensor) else v for v in self.gripper_params
1029
+ )
1030
 
1031
  # im_list = []
1032
  for i in range(15):
 
1090
  self.params[k][k2] = v2.detach().cpu() if isinstance(v2, torch.Tensor) else v2
1091
  else:
1092
  self.params[k] = v.detach().cpu() if isinstance(v, torch.Tensor) else v
1093
+ self.table_params = tuple(
1094
+ v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.table_params
1095
+ )
1096
+ self.gripper_params = tuple(
1097
+ v.detach().cpu() if isinstance(v, torch.Tensor) else v for v in self.gripper_params
1098
+ )
1099
 
1100
  make_video(root / 'log/temp', root / f'log/gs/temp/form_video.mp4', '%04d.png', 5)
1101
 
 
1212
 
1213
  with gr.Column(scale=2):
1214
  _ = gr.Button(visible=False) # empty placeholder
1215
+
1216
+ # with gr.Row():
1217
+
1218
+ # # with gr.Column(scale=2):
1219
+ # # form_3dgs_orig = gr.Model3D(
1220
+ # # label='Original Gaussian Splats',
1221
+ # # value=None,
1222
+ # # )
1223
+
1224
+ # with gr.Column(scale=2):
1225
+ # form_video_2 = gr.Video(
1226
+ # label='Predicted video',
1227
+ # value=None,
1228
+ # format='mp4',
1229
+ # width=self.width,
1230
+ # height=self.height,
1231
+ # )
1232
+
1233
+ # with gr.Column(scale=2):
1234
+ # form_3dgs_pred_2 = gr.Model3D(
1235
+ # label='Predicted Gaussians',
1236
+ # height=self.height,
1237
+ # value=None,
1238
+ # clear_color=[0, 0, 0, 0],
1239
+ # )
1240
+
1241
+ # # Layout
1242
+ # with gr.Row():
1243
+ # with gr.Column(scale=2):
1244
+ # with gr.Row():
1245
+ # run_reset_2 = gr.Button("Reset")
1246
+
1247
+ # with gr.Row():
1248
+ # with gr.Column():
1249
+ # run_xminus_2 = gr.Button("x-")
1250
+ # with gr.Column():
1251
+ # run_xplus_2 = gr.Button("x+")
1252
+
1253
+ # with gr.Row():
1254
+ # with gr.Column():
1255
+ # run_yminus_2 = gr.Button("y-")
1256
+ # with gr.Column():
1257
+ # run_yplus_2 = gr.Button("y+")
1258
+
1259
+ # with gr.Row():
1260
+ # with gr.Column():
1261
+ # run_zminus_2 = gr.Button("z-")
1262
+ # with gr.Column():
1263
+ # run_zplus_2 = gr.Button("z+")
1264
+
1265
+ # with gr.Column(scale=2):
1266
+ # _ = gr.Button(visible=False) # empty placeholder
1267
 
1268
  # Set up callbacks
1269
  run_reset.click(self.reset,
 
1294
  inputs=[],
1295
  outputs=[form_video, form_3dgs_pred])
1296
 
1297
+ # Set up callbacks
1298
+ # run_reset_2.click(self.reset_2,
1299
+ # inputs=[],
1300
+ # outputs=[form_video_2, form_3dgs_pred_2])
1301
+
1302
+ # run_xplus_2.click(self.on_click_run_xplus_2,
1303
+ # inputs=[],
1304
+ # outputs=[form_video_2, form_3dgs_pred_2])
1305
+
1306
+ # run_xminus_2.click(self.on_click_run_xminus_2,
1307
+ # inputs=[],
1308
+ # outputs=[form_video_2, form_3dgs_pred_2])
1309
+
1310
+ # run_yplus_2.click(self.on_click_run_yplus_2,
1311
+ # inputs=[],
1312
+ # outputs=[form_video_2, form_3dgs_pred_2])
1313
+
1314
+ # run_yminus_2.click(self.on_click_run_yminus_2,
1315
+ # inputs=[],
1316
+ # outputs=[form_video_2, form_3dgs_pred_2])
1317
+
1318
+ # run_zplus_2.click(self.on_click_run_zplus_2,
1319
+ # inputs=[],
1320
+ # outputs=[form_video_2, form_3dgs_pred_2])
1321
+
1322
+ # run_zminus_2.click(self.on_click_run_zminus_2,
1323
+ # inputs=[],
1324
+ # outputs=[form_video_2, form_3dgs_pred_2])
1325
+
1326
  app.launch(share=share)
1327
 
1328