update
app.py CHANGED
@@ -27,8 +27,8 @@ import yaml
 import matplotlib.pyplot as plt
 from sklearn.neighbors import NearestNeighbors
 import spaces
-
-
+from spaces import zero
+zero.startup()

 def install_cuda_toolkit():
     # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
@@ -51,7 +51,7 @@ def install_cuda_toolkit():
 install_cuda_toolkit()

 gs_path = Path(__file__).parent / "src/third-party/diff-gaussian-rasterization-w-depth"
-subprocess.check_call(["pip", "install", "-e", str(gs_path)])
+subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", str(gs_path)])
 site.main() # re-processes every *.pth in site-packages
 importlib.invalidate_caches()
 diff_gaussian_rasterization = importlib.import_module("diff_gaussian_rasterization")
@@ -113,7 +113,6 @@ def fps_new(x, enabled, n, device, random_start=False):

 class DynamicsVisualizer:

-    @spaces.GPU
     def __init__(self):
         wp.init()

@@ -143,7 +142,7 @@ class DynamicsVisualizer:
         cfg.sim.uniform = True
         cfg.sim.use_pv = False

-        device = torch.device('cuda
+        device = torch.device('cuda')

         self.cfg = cfg
         self.device = device
@@ -156,14 +155,6 @@ class DynamicsVisualizer:
         self.dt_base = cfg.sim.dt
         self.high_freq_pred = True

-        gpus = [int(gpu) for gpu in cfg.gpus]
-        wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
-        torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
-        device_count = len(torch_devices)
-        assert device_count == 1
-        self.wp_device = wp_devices[0]
-        self.torch_device = torch_devices[0]
-
         seed = cfg.seed
         random.seed(seed)
         np.random.seed(seed)
@@ -172,8 +163,7 @@ class DynamicsVisualizer:
         torch.backends.cudnn.benchmark = True

         self.clear()
-
-    @spaces.GPU
+
     def clear(self, clear_params=True):
         self.metadata = {}
         self.config = {}
@@ -206,7 +196,6 @@ class DynamicsVisualizer:
         self.material = None
         self.friction = None

-    @spaces.GPU
     def load_scaniverse(self, data_path):

         ### load splat params
@@ -303,7 +292,6 @@ class DynamicsVisualizer:
         # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32)
         self.state['gripper_radius'] = cfg.model.gripper_radius

-    @spaces.GPU
     def load_params(self, params_path, remove_low_opa=True, remove_black=False):
         pts, colors, scales, quats, opacities = read_splat(params_path)

@@ -369,7 +357,6 @@ class DynamicsVisualizer:
         self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
         self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)

-    @spaces.GPU
     def set_camera(self, w, h, intr, w2c=None, R=None, t=None, near=0.01, far=100.0):
         if w2c is None:
             assert R is not None and t is not None
@@ -382,7 +369,6 @@ class DynamicsVisualizer:
         }
         self.config = {'near': near, 'far': far}

-    @spaces.GPU
     def load_eef(self, grippers=None, eef_t=None):
         assert self.state['prev_key_pos'] is None

@@ -405,7 +391,6 @@ class DynamicsVisualizer:
         # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32) + 0.001
         self.state['gripper_radius'] = self.cfg.model.gripper_radius

-    @spaces.GPU
     def load_preprocess_metadata(self, p_x_orig):
         cfg = self.cfg
         dx = cfg.sim.num_grids[-1]
@@ -442,7 +427,6 @@ class DynamicsVisualizer:
             'global_translation': global_translation,
         }

-    @spaces.GPU
     @torch.no_grad
     def render(self, render_data, cam_id, bg=[0.7, 0.7, 0.7]):
         render_data = {k: v.to(self.device) for k, v in render_data.items()}
@@ -452,7 +436,6 @@ class DynamicsVisualizer:
         im, _, depth, = GaussianRasterizer(raster_settings=cam)(**render_data)
         return im, depth

-    @spaces.GPU
     def knn_relations(self, bones):
         k = self.k_rel
         knn = NearestNeighbors(n_neighbors=k+1, algorithm='kd_tree').fit(bones.detach().cpu().numpy())
@@ -460,7 +443,6 @@ class DynamicsVisualizer:
         indices = indices[:, 1:] # exclude self
         return indices

-    @spaces.GPU
     def knn_weights_brute(self, bones, pts):
         k = self.k_wgt
         dist = torch.norm(pts[:, None] - bones, dim=-1) # (n_pts, n_bones)
@@ -473,7 +455,6 @@ class DynamicsVisualizer:
         weights_all[torch.arange(pts.shape[0])[:, None], indices] = weights
         return weights_all

-    @spaces.GPU
     def update_camera(self, k, w2c, w=None, h=None, near=0.01, far=100.0):
         self.metadata['k'] = k
         self.metadata['w2c'] = w2c
@@ -484,7 +465,6 @@ class DynamicsVisualizer:
         self.config['near'] = near
         self.config['far'] = far

-    @spaces.GPU
     def init_model(self, batch_size, num_steps, num_particles, ckpt_path=None):
         self.cfg.sim.num_steps = num_steps
         cfg = self.cfg
@@ -520,13 +500,11 @@ class DynamicsVisualizer:
         self.material = material
         self.friction = friction

-    @spaces.GPU
     def reload_model(self, num_steps): # only change num_steps
         self.cfg.sim.num_steps = num_steps
         sim = CacheDiffSimWithFrictionBatch(self.cfg, num_steps, 1, self.wp_device, requires_grad=True)
         self.sim = sim

-    @spaces.GPU
     @torch.no_grad
     def step(self):
         cfg = self.cfg
@@ -616,7 +594,6 @@ class DynamicsVisualizer:
         self.state['sub_pos'] = None
         # self.state['sub_pos_timestamps'] = None

-    @spaces.GPU
     def preprocess_x(self, p_x): # viewer frame to model frame (not data frame)
         R = self.preprocess_metadata['R']
         R_viewer = self.preprocess_metadata['R_viewer']
@@ -634,7 +611,6 @@ class DynamicsVisualizer:

         return p_x

-    @spaces.GPU
     def preprocess_gripper(self, grippers): # viewer frame to model frame (not data frame)
         R = self.preprocess_metadata['R']
         R_viewer = self.preprocess_metadata['R_viewer']
@@ -647,7 +623,6 @@ class DynamicsVisualizer:

         return grippers

-    @spaces.GPU
     def inverse_preprocess_x(self, p_x): # model frame (not data frame) to viewer frame
         R = self.preprocess_metadata['R']
         R_viewer = self.preprocess_metadata['R_viewer']
@@ -660,7 +635,6 @@ class DynamicsVisualizer:

         return p_x

-    @spaces.GPU
     def inverse_preprocess_gripper(self, grippers): # model frame (not data frame) to viewer frame
         R = self.preprocess_metadata['R']
         R_viewer = self.preprocess_metadata['R_viewer']
@@ -673,7 +647,6 @@ class DynamicsVisualizer:

         return grippers

-    @spaces.GPU
     def rotate(self, params, rot_mat):
         scale = np.linalg.norm(rot_mat, axis=1, keepdims=True)

@@ -686,7 +659,6 @@ class DynamicsVisualizer:
         }
         return params

-    @spaces.GPU
     def preprocess_gs(self, params):
         if isinstance(params, dict):
             xyz = params['means3D']
@@ -736,7 +708,6 @@ class DynamicsVisualizer:

         return params

-    @spaces.GPU
     def preprocess_bg_gs(self):
         t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params
         g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params
@@ -784,7 +755,6 @@ class DynamicsVisualizer:
         self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
         self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities

-    @spaces.GPU
     def update_rendervar(self, rendervar):
         p_x = self.state['x']
         p_x_viewer = self.inverse_preprocess_x(p_x)
@@ -861,7 +831,6 @@ class DynamicsVisualizer:

         return rendervar, rendervar_full

-    @spaces.GPU
     def reset_state(self, params, visualize_image=False, init=False):
         xyz_0 = params['means3D']
         rgb_0 = params['rgb_colors']
@@ -931,6 +900,15 @@ class DynamicsVisualizer:

     @spaces.GPU
     def reset(self):
+        wp.init()
+        gpus = [int(gpu) for gpu in self.cfg.gpus]
+        wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
+        torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
+        device_count = len(torch_devices)
+        assert device_count == 1
+        self.wp_device = wp_devices[0]
+        self.torch_device = torch_devices[0]
+
         params = self.preprocess_gs(self.params)
         if self.with_bg:
             self.preprocess_bg_gs()
@@ -971,7 +949,6 @@ class DynamicsVisualizer:

         return form_video, form_3dgs_pred

-    @spaces.GPU
     def run_command(self, unit_command):

         os.system('rm -rf ' + str(root / 'log/temp/*'))
@@ -1058,25 +1035,78 @@ class DynamicsVisualizer:
         )
         return form_video, form_3dgs_pred

+    @spaces.GPU
     def on_click_run_xplus(self):
+        wp.init()
+        gpus = [int(gpu) for gpu in self.cfg.gpus]
+        wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
+        torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
+        device_count = len(torch_devices)
+        assert device_count == 1
+        self.wp_device = wp_devices[0]
+        self.torch_device = torch_devices[0]
         return self.run_command([5.0, 0, 0])

+    @spaces.GPU
     def on_click_run_xminus(self):
+        wp.init()
+        gpus = [int(gpu) for gpu in self.cfg.gpus]
+        wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
+        torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
+        device_count = len(torch_devices)
+        assert device_count == 1
+        self.wp_device = wp_devices[0]
+        self.torch_device = torch_devices[0]
         return self.run_command([-5.0, 0, 0])

+    @spaces.GPU
     def on_click_run_yplus(self):
+        wp.init()
+        gpus = [int(gpu) for gpu in self.cfg.gpus]
+        wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
+        torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
+        device_count = len(torch_devices)
+        assert device_count == 1
+        self.wp_device = wp_devices[0]
+        self.torch_device = torch_devices[0]
         return self.run_command([0, 5.0, 0])

+    @spaces.GPU
     def on_click_run_yminus(self):
+        wp.init()
+        gpus = [int(gpu) for gpu in self.cfg.gpus]
+        wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
+        torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
+        device_count = len(torch_devices)
+        assert device_count == 1
+        self.wp_device = wp_devices[0]
+        self.torch_device = torch_devices[0]
         return self.run_command([0, -5.0, 0])

+    @spaces.GPU
     def on_click_run_zplus(self):
+        wp.init()
+        gpus = [int(gpu) for gpu in self.cfg.gpus]
+        wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
+        torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
+        device_count = len(torch_devices)
+        assert device_count == 1
+        self.wp_device = wp_devices[0]
+        self.torch_device = torch_devices[0]
         return self.run_command([0, 0, 5.0])

+    @spaces.GPU
     def on_click_run_zminus(self):
+        wp.init()
+        gpus = [int(gpu) for gpu in self.cfg.gpus]
+        wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
+        torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
+        device_count = len(torch_devices)
+        assert device_count == 1
+        self.wp_device = wp_devices[0]
+        self.torch_device = torch_devices[0]
         return self.run_command([0, 0, -5.0])

-    @spaces.GPU
     def launch(self, share=False):
         in_dir = root / 'log/gs/ckpts/rope_scene_1'
         batch_size = 1
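Two patterns above are worth noting. Running pip as `subprocess.check_call([sys.executable, "-m", "pip", ...])` ties the editable install to the interpreter that is running the app, so the following `site.main()` and `importlib.import_module` can find the freshly installed package. The larger change is the ZeroGPU pattern: on a ZeroGPU Space, CUDA is only attached while a function decorated with `@spaces.GPU` runs, so the per-method decorators are dropped and the Warp/torch device handles are bound lazily inside the decorated Gradio callbacks rather than in `__init__`. A minimal sketch of that pattern, assuming a single-GPU `cfg.gpus` list like the one used above; the `Visualizer` class, `_bind_devices` helper, and `run_command` stub are illustrative stand-ins, not the app's real code:

```python
import spaces
import torch
import warp as wp


class Visualizer:
    def __init__(self, cfg):
        # CPU-only setup: no CUDA or Warp device calls outside a GPU context.
        self.cfg = cfg

    def _bind_devices(self):
        # Called from inside a @spaces.GPU callback, where the GPU is attached.
        wp.init()
        gpus = [int(gpu) for gpu in self.cfg.gpus]
        assert len(gpus) == 1  # single-GPU app
        self.wp_device = wp.get_device(f'cuda:{gpus[0]}')
        self.torch_device = torch.device(f'cuda:{gpus[0]}')

    def run_command(self, unit_command):
        # Stand-in for the real rollout; just shows the device is usable.
        return torch.tensor(unit_command, device=self.torch_device)

    @spaces.GPU
    def on_click_run_xplus(self):
        # GPU is attached only for the duration of this callback.
        self._bind_devices()
        return self.run_command([5.0, 0, 0])


if __name__ == "__main__":
    from types import SimpleNamespace
    viz = Visualizer(SimpleNamespace(gpus=[0]))
    print(viz.on_click_run_xplus())
```

Outside of Spaces the `spaces` package treats `@spaces.GPU` as a no-op, so the same callbacks should run unchanged on a machine with a local GPU.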