import gradio as gr
import importlib
import sys
import site
from PIL import Image
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm, trange
import random
import math
import hydra
import numpy as np
import torch
import torch.nn as nn
print(torch.__version__)
print(torch.version.cuda)

import warp as wp
import glob
from torch.utils.data import DataLoader
import os
import subprocess
import time
import cv2
import copy
import kornia
import yaml
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
import spaces
from spaces import zero

zero.startup()
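
# Note: this Space bootstraps its own GPU toolchain at import time. The rough
# flow, as reconstructed from the calls below, is:
#   1. download and silently install the CUDA 12.4 toolkit,
#   2. pip-install the bundled diff-gaussian-rasterization extension in
#      editable mode so it compiles against that toolkit,
#   3. pip-install a DGL wheel built for torch 2.4 / cu124.
# This only makes sense inside an ephemeral container (e.g. a Hugging Face
# ZeroGPU Space); on a normal machine these would be installed ahead of time.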

def install_cuda_toolkit():
    # Download the CUDA 12.4 toolkit installer and run it in silent mode.
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])

    # Point the build environment at the freshly installed toolkit.
    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
    os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
        os.environ["CUDA_HOME"],
        "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
    )
    # Compile for compute capability 8.0 (A100) and 8.6 (A10G / RTX 30xx).
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"


install_cuda_toolkit()

# Build and import the differentiable Gaussian rasterizer extension.
gs_path = Path(__file__).parent / "src/third-party/diff-gaussian-rasterization-w-depth"
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", str(gs_path)])
site.main()  # refresh sys.path so the newly installed package is visible
importlib.invalidate_caches()
diff_gaussian_rasterization = importlib.import_module("diff_gaussian_rasterization")

# Install a DGL wheel matching the torch 2.4 / CUDA 12.4 build above.
os.system('pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html')

sys.path.insert(0, str(Path(__file__).parent / "src"))
sys.path.append(str(Path(__file__).parent / "src" / "experiments"))

from pgnd.sim import Friction, CacheDiffSimWithFrictionBatch, StaticsBatch, CollidersBatch
from pgnd.material import PGNDModel
from pgnd.utils import Logger, get_root, mkdir
from pgnd.ffmpeg import make_video

from real_world.utils.render_utils import interpolate_motions
from real_world.gs.helpers import setup_camera
from real_world.gs.convert import save_to_splat, read_splat

from diff_gaussian_rasterization import GaussianRasterizer
from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera

root = Path(__file__).parent / "src" / "experiments"
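
# The quaternion helpers below wrap kornia's conversions. kornia represents
# quaternions in (w, x, y, z) order; the splat files loaded in this script
# are assumed to follow the same convention.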

def quat2mat(quat):
    return kornia.geometry.conversions.quaternion_to_rotation_matrix(quat)


def mat2quat(mat):
    return kornia.geometry.conversions.rotation_matrix_to_quaternion(mat)


def fps(x, enabled, n, device, random_start=False):
    # Farthest-point-sample n indices from the enabled subset of x.
    from dgl.geometry import farthest_point_sampler
    # `enabled` must be a contiguous prefix mask: all True, or True then False.
    assert torch.diff(enabled * 1.0).sum() in [0.0, -1.0]
    start_idx = random.randint(0, enabled.sum() - 1) if random_start else 0
    fps_idx = farthest_point_sampler(x[enabled][None], n, start_idx=start_idx)[0]
    fps_idx = fps_idx.to(x.device)
    return fps_idx
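
# `set_camera` below calls `Rt_to_w2c`, which is neither defined nor imported
# in this file. The sketch here is an assumption: it treats (R, t) as the
# camera-to-world rotation and translation and returns the inverted 4x4
# matrix. If the project ships its own helper, import that instead.
def Rt_to_w2c(R, t):
    c2w = np.eye(4)
    c2w[:3, :3] = R
    c2w[:3, 3] = t
    return np.linalg.inv(c2w)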

class DynamicsVisualizer:

    def __init__(self):
        self.width = 640
        self.height = 480

        # Per-task checkpoints: name -> [log dir, run name, iteration, [start, end] episodes].
        best_models = {
            'cloth': ['cloth', 'train', 100000, [610, 650]],
            'rope': ['rope', 'train', 100000, [651, 691]],
            'paperbag': ['paperbag', 'train', 100000, [200, 220]],
            'sloth': ['sloth', 'train', 100000, [113, 133]],
            'box': ['box', 'train', 100000, [306, 323]],
            'bread': ['bread', 'train', 100000, [143, 163]],
        }

        task_name = 'rope'

        with open(root / f'log/{best_models[task_name][0]}/{best_models[task_name][1]}/hydra.yaml', 'r') as f:
            config = yaml.load(f, Loader=yaml.CLoader)
        cfg = OmegaConf.create(config)

        cfg.iteration = best_models[task_name][2]
        cfg.start_episode = best_models[task_name][3][0]
        cfg.end_episode = best_models[task_name][3][1]
        cfg.sim.num_steps = 1000
        cfg.sim.gripper_forcing = False
        cfg.sim.uniform = True
        cfg.sim.use_pv = False

        device = torch.device('cuda')

        self.cfg = cfg
        self.device = device
        self.k_rel = 8   # neighbors used for bone-bone relations
        self.k_wgt = 16  # neighbors used for skinning weights
        self.with_bg = True
        self.render_gripper = True
        self.verbose = False

        self.dt_base = cfg.sim.dt
        self.high_freq_pred = True

        seed = cfg.seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.clear()
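    # State layout: `x`/`v` are the current particle positions/velocities,
    # `x_his`/`v_his` a short history window, and `x_pred`/`v_pred` the
    # one-step-ahead prediction. The end-effector is tracked as a previous
    # keyframe position (`prev_key_pos`) plus accumulated sub-step positions
    # (`sub_pos`) that `step()` consumes to form the next keyframe.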
    def clear(self, clear_params=True):
        self.metadata = {}
        self.config = {}
        if clear_params:
            self.params = None
        self.state = {
            'x': None,
            'v': None,
            'x_his': None,
            'v_his': None,
            'x_pred': None,
            'v_pred': None,
            'clip_bound': None,
            'enabled': None,
            'prev_key_pos': None,
            'prev_key_pos_timestamp': None,
            'sub_pos': None,
            'sub_pos_timestamps': None,
            'gripper_radius': None,
        }
        self.preprocess_metadata = None
        self.table_params = None
        self.gripper_params = None

        self.sim = None
        self.statics = None
        self.colliders = None
        self.material = None
        self.friction = None
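    # `load_scaniverse` expects `data_path` to contain:
    #   object.splat   - Gaussians of the manipulated object
    #   table.splat    - Gaussians of the table (static background)
    #   gripper.splat  - Gaussians of the gripper
    #   eef_xyz.txt    - a single xyz end-effector position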
    def load_scaniverse(self, data_path):
        params_obj = read_splat(data_path / 'object.splat')
        params_table = read_splat(data_path / 'table.splat')
        params_robot = read_splat(data_path / 'gripper.splat')

        pts, colors, scales, quats, opacities = params_obj
        self.params = {
            'means3D': torch.from_numpy(pts).to(torch.float32).to(self.device),
            'rgb_colors': torch.from_numpy(colors).to(torch.float32).to(self.device),
            'log_scales': torch.log(torch.from_numpy(scales).to(torch.float32).to(self.device)),
            'unnorm_rotations': torch.from_numpy(quats).to(torch.float32).to(self.device),
            'logit_opacities': torch.logit(torch.from_numpy(opacities).to(torch.float32).to(self.device)),
        }

        t_pts, t_colors, t_scales, t_quats, t_opacities = params_table
        t_pts = torch.tensor(t_pts).to(torch.float32).to(self.device)
        t_colors = torch.tensor(t_colors).to(torch.float32).to(self.device)
        t_scales = torch.tensor(t_scales).to(torch.float32).to(self.device)
        t_quats = torch.tensor(t_quats).to(torch.float32).to(self.device)
        t_opacities = torch.tensor(t_opacities).to(torch.float32).to(self.device)

        g_pts, g_colors, g_scales, g_quats, g_opacities = params_robot
        g_pts = torch.tensor(g_pts).to(torch.float32).to(self.device)
        g_colors = torch.tensor(g_colors).to(torch.float32).to(self.device)
        g_scales = torch.tensor(g_scales).to(torch.float32).to(self.device)
        g_quats = torch.tensor(g_quats).to(torch.float32).to(self.device)
        g_opacities = torch.tensor(g_opacities).to(torch.float32).to(self.device)

        self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
        self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities

        n_particles = self.cfg.sim.n_particles
        self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
        self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)

        # Compute the transform that maps the scan frame into the simulation
        # frame: swap axes so the scan's vertical axis aligns with the
        # simulator's y axis, then translate so the object is centered in
        # x/z and rests just above the clip bound in y.
        cfg = self.cfg
        dx = cfg.sim.num_grids[-1]

        p_x = torch.tensor(pts).to(torch.float32).to(self.device)
        R = torch.tensor(
            [[1, 0, 0],
             [0, 0, -1],
             [0, 1, 0]]
        ).to(p_x.device).to(p_x.dtype)
        p_x_rotated = p_x @ R.T

        scale = 1.0
        p_x_rotated_scaled = p_x_rotated * scale

        global_translation = torch.tensor([
            0.5 - p_x_rotated_scaled[:, 0].mean(),
            dx * (cfg.model.clip_bound + 0.5) - p_x_rotated_scaled[:, 1].min(),
            0.5 - p_x_rotated_scaled[:, 2].mean(),
        ], dtype=p_x_rotated_scaled.dtype, device=p_x_rotated_scaled.device)

        R_viewer = torch.tensor(
            [[1, 0, 0],
             [0, 0, -1],
             [0, 1, 0]]
        ).to(p_x.device).to(p_x.dtype)
        t_viewer = torch.tensor([0, 0, 0]).to(p_x.device).to(p_x.dtype)

        self.preprocess_metadata = {
            'R': R,
            'R_viewer': R_viewer,
            't_viewer': t_viewer,
            'scale': scale,
            'global_translation': global_translation,
        }

        # Load the initial end-effector position and map it into the
        # simulation frame with the same transform.
        grippers = np.loadtxt(data_path / 'eef_xyz.txt')[None]
        assert grippers.shape == (1, 3)
        grippers = torch.tensor(grippers).to(self.device).to(torch.float32)

        R = self.preprocess_metadata['R']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']
        grippers[:, :3] = grippers[:, :3] @ R.T
        grippers[:, :3] = grippers[:, :3] * scale
        grippers[:, :3] += global_translation

        assert grippers.shape[0] == 1
        self.state['prev_key_pos'] = grippers[:, :3]

        self.state['gripper_radius'] = cfg.model.gripper_radius
    def load_params(self, params_path, remove_low_opa=True, remove_black=False):
        pts, colors, scales, quats, opacities = read_splat(params_path)

        if remove_low_opa:
            low_opa_idx = opacities[:, 0] < 0.1
            pts = pts[~low_opa_idx]
            colors = colors[~low_opa_idx]
            quats = quats[~low_opa_idx]
            opacities = opacities[~low_opa_idx]
            scales = scales[~low_opa_idx]

        if remove_black:
            low_color_idx = colors.sum(axis=-1) < 0.5
            pts = pts[~low_color_idx]
            colors = colors[~low_color_idx]
            quats = quats[~low_color_idx]
            opacities = opacities[~low_color_idx]
            scales = scales[~low_color_idx]

        self.params = {
            'means3D': torch.from_numpy(pts).to(torch.float32).to(self.device),
            'rgb_colors': torch.from_numpy(colors).to(torch.float32).to(self.device),
            'log_scales': torch.log(torch.from_numpy(scales).to(torch.float32).to(self.device)),
            'unnorm_rotations': torch.from_numpy(quats).to(torch.float32).to(self.device),
            'logit_opacities': torch.logit(torch.from_numpy(opacities).to(torch.float32).to(self.device)),
        }

        table_splat = root / 'log/gs/ckpts/table.splat'
        sphere_splat = root / 'log/gs/ckpts/sphere.splat'
        gripper_splat = root / 'log/gs/ckpts/gripper.splat'

        table_params = read_splat(table_splat)

        t_pts, t_colors, t_scales, t_quats, t_opacities = table_params
        t_pts = torch.tensor(t_pts).to(torch.float32).to(self.device)
        t_colors = torch.tensor(t_colors).to(torch.float32).to(self.device)
        t_scales = torch.tensor(t_scales).to(torch.float32).to(self.device)
        t_quats = torch.tensor(t_quats).to(torch.float32).to(self.device)
        t_opacities = torch.tensor(t_opacities).to(torch.float32).to(self.device)

        t_pts = t_pts + torch.tensor([0, 0, 0.02]).to(torch.float32).to(self.device)

        gripper_params = read_splat(gripper_splat)

        g_pts, g_colors, g_scales, g_quats, g_opacities = gripper_params
        g_pts = torch.tensor(g_pts).to(torch.float32).to(self.device)
        g_colors = torch.tensor(g_colors).to(torch.float32).to(self.device)
        g_scales = torch.tensor(g_scales).to(torch.float32).to(self.device)
        g_quats = torch.tensor(g_quats).to(torch.float32).to(self.device)
        g_opacities = torch.tensor(g_opacities).to(torch.float32).to(self.device)

        self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
        self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities

        n_particles = self.cfg.sim.n_particles
        self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
        self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)
    def set_camera(self, w, h, intr, w2c=None, R=None, t=None, near=0.01, far=100.0):
        if w2c is None:
            assert R is not None and t is not None
            w2c = Rt_to_w2c(R, t)
        self.metadata = {
            'w': w,
            'h': h,
            'k': intr,
            'w2c': w2c,
        }
        self.config = {'near': near, 'far': far}
    def load_eef(self, grippers=None, eef_t=None):
        assert self.state['prev_key_pos'] is None

        if grippers is not None:
            grippers = torch.tensor(grippers).to(self.device).to(torch.float32)
            eef_t = torch.tensor(eef_t).to(self.device).to(torch.float32)
            grippers[:, :3] = grippers[:, :3] + eef_t

        # Map the end-effector position into the simulation frame.
        R = self.preprocess_metadata['R']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']
        grippers[:, :3] = grippers[:, :3] @ R.T
        grippers[:, :3] = grippers[:, :3] * scale
        grippers[:, :3] += global_translation

        assert grippers.shape[0] == 1
        self.state['prev_key_pos'] = grippers[:, :3]

        self.state['gripper_radius'] = self.cfg.model.gripper_radius
    def load_preprocess_metadata(self, p_x_orig):
        cfg = self.cfg
        dx = cfg.sim.num_grids[-1]

        p_x_orig = p_x_orig.to(self.device)
        R = torch.tensor(
            [[1, 0, 0],
             [0, 0, -1],
             [0, 1, 0]]
        ).to(p_x_orig.device).to(p_x_orig.dtype)
        p_x_orig_rotated = torch.einsum('nij,jk->nik', p_x_orig, R.T)

        scale = 1.0
        p_x_orig_rotated_scaled = p_x_orig_rotated * scale

        global_translation = torch.tensor([
            0.5 - p_x_orig_rotated_scaled[:, :, 0].mean(),
            dx * (cfg.model.clip_bound + 0.5) - p_x_orig_rotated_scaled[:, :, 1].min(),
            0.5 - p_x_orig_rotated_scaled[:, :, 2].mean(),
        ], dtype=p_x_orig_rotated_scaled.dtype, device=p_x_orig_rotated_scaled.device)

        R_viewer = torch.tensor(
            [[1, 0, 0],
             [0, 0, -1],
             [0, 1, 0]]
        ).to(p_x_orig.device).to(p_x_orig.dtype)
        t_viewer = torch.tensor([0, 0, 0]).to(p_x_orig.device).to(p_x_orig.dtype)

        self.preprocess_metadata = {
            'R': R,
            'R_viewer': R_viewer,
            't_viewer': t_viewer,
            'scale': scale,
            'global_translation': global_translation,
        }
    @torch.no_grad
    def render(self, render_data, cam_id, bg=[0.7, 0.7, 0.7]):
        render_data = {k: v.to(self.device) for k, v in render_data.items()}
        w, h = self.metadata['w'], self.metadata['h']
        k, w2c = self.metadata['k'], self.metadata['w2c']
        cam = setup_camera(w, h, k, w2c, self.config['near'], self.config['far'], bg)
        im, _, depth = GaussianRasterizer(raster_settings=cam)(**render_data)
        return im, depth

    def knn_relations(self, bones):
        # For each bone, indices of its k_rel nearest neighbors (excluding itself).
        k = self.k_rel
        knn = NearestNeighbors(n_neighbors=k + 1, algorithm='kd_tree').fit(bones.detach().cpu().numpy())
        _, indices = knn.kneighbors(bones.detach().cpu().numpy())
        indices = indices[:, 1:]  # drop the self-neighbor
        return indices

    def knn_weights_brute(self, bones, pts):
        # Inverse-distance skinning weights of each point w.r.t. its k_wgt nearest bones.
        k = self.k_wgt
        dist = torch.norm(pts[:, None] - bones, dim=-1)
        _, indices = torch.topk(dist, k, dim=-1, largest=False)
        bones_selected = bones[indices]
        dist = torch.norm(bones_selected - pts[:, None], dim=-1)
        weights = 1 / (dist + 1e-6)
        weights = weights / weights.sum(dim=-1, keepdim=True)
        weights_all = torch.zeros((pts.shape[0], bones.shape[0]), device=pts.device)
        weights_all[torch.arange(pts.shape[0])[:, None], indices] = weights
        return weights_all
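    # Example (hypothetical shapes): with 2000 Gaussian centers and 100 bones,
    #   weights = self.knn_weights_brute(bones, pts)   # -> (2000, 100)
    # each row is nonzero only at the k_wgt nearest bones and sums to 1; these
    # rows are the per-Gaussian weights consumed by interpolate_motions.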
    def update_camera(self, k, w2c, w=None, h=None, near=0.01, far=100.0):
        self.metadata['k'] = k
        self.metadata['w2c'] = w2c
        if w is not None:
            self.metadata['w'] = w
        if h is not None:
            self.metadata['h'] = h
        self.config['near'] = near
        self.config['far'] = far
    def init_model(self, batch_size, num_steps, num_particles, ckpt_path=None):
        self.cfg.sim.num_steps = num_steps
        cfg = self.cfg

        sim = CacheDiffSimWithFrictionBatch(cfg, num_steps, batch_size, self.wp_device, requires_grad=True)

        statics = StaticsBatch()
        statics.init(shape=(batch_size, num_particles), device=self.wp_device)
        statics.update_clip_bound(self.state['clip_bound'])
        statics.update_enabled(self.state['enabled'][None])
        colliders = CollidersBatch()
        colliders.init(shape=(batch_size, cfg.sim.num_grippers), device=self.wp_device)

        self.sim = sim
        self.statics = statics
        self.colliders = colliders

        # Fall back to the bundled rope checkpoint when no path is given.
        if ckpt_path is None:
            ckpt_path = root / 'log/rope/train/ckpt/100000.pt'
        ckpt = torch.load(ckpt_path, map_location=self.torch_device)

        material: nn.Module = PGNDModel(cfg)
        material.to(self.torch_device)
        material.load_state_dict(ckpt['material'])
        material.requires_grad_(False)
        material.eval()

        if 'friction' in ckpt:
            friction = ckpt['friction']['mu'].reshape(-1, 1)
        else:
            friction = torch.tensor(cfg.model.friction.value, device=self.torch_device).reshape(-1, 1)

        self.material = material
        self.friction = friction
    def reload_model(self, num_steps):
        self.cfg.sim.num_steps = num_steps
        sim = CacheDiffSimWithFrictionBatch(self.cfg, num_steps, 1, self.wp_device, requires_grad=True)
        self.sim = sim
    @torch.no_grad
    def step(self):
        cfg = self.cfg
        batch_size = 1
        num_steps = 1
        num_particles = cfg.sim.n_particles

        # Shift history and promote the previous prediction to the current state.
        self.state['x_his'] = torch.cat([self.state['x_his'][1:], self.state['x'][None]], dim=0)
        self.state['v_his'] = torch.cat([self.state['v_his'][1:], self.state['v'][None]], dim=0)
        self.state['x'] = self.state['x_pred'].clone()
        self.state['v'] = self.state['v_pred'].clone()

        eef_xyz_key = self.state['prev_key_pos']
        eef_xyz_sub = self.state['sub_pos']

        if eef_xyz_sub is None:
            return

        # Use the latest sub-step position as the next keyframe and derive a
        # constant end-effector velocity over the simulation step.
        eef_xyz_key_next = eef_xyz_sub[-1]
        eef_v = (eef_xyz_key_next - eef_xyz_key) / cfg.sim.dt
        if self.verbose:
            print('delta_t:', np.round(cfg.sim.dt, 4))
            print('eef_xyz_key_next:', eef_xyz_key_next.cpu().numpy().tolist())
            print('eef_xyz_key:', eef_xyz_key.cpu().numpy().tolist())
            print('v:', eef_v.cpu().numpy().tolist())

        self.reload_model(num_steps)

        if cfg.sim.num_grippers > 0:
            # Gripper state vector (15 dims): position (0:3), velocity (3:6),
            # quaternion (6:10), quaternion velocity (10:13), radius (13),
            # gripper open state (14).
            grippers = torch.zeros((batch_size, cfg.sim.num_grippers, 15), device=self.torch_device)
            eef_quat = torch.tensor([1, 0, 0, 0], dtype=torch.float32, device=self.torch_device).repeat(batch_size, cfg.sim.num_grippers, 1)
            eef_quat_vel = torch.zeros((batch_size, cfg.sim.num_grippers, 3), dtype=torch.float32, device=self.torch_device)
            eef_gripper = torch.zeros((batch_size, cfg.sim.num_grippers), dtype=torch.float32, device=self.torch_device)
            grippers[:, :, :3] = eef_xyz_key
            grippers[:, :, 3:6] = eef_v
            grippers[:, :, 6:10] = eef_quat
            grippers[:, :, 10:13] = eef_quat_vel
            grippers[:, :, 13] = cfg.model.gripper_radius
            grippers[:, :, 14] = eef_gripper
            self.colliders.initialize_grippers(grippers)

        x = self.state['x'].clone()[None].repeat(batch_size, 1, 1)
        v = self.state['v'].clone()[None].repeat(batch_size, 1, 1)
        x_his = self.state['x_his'].permute(1, 0, 2).clone()
        assert x_his.shape[0] == num_particles
        x_his = x_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1)
        v_his = self.state['v_his'].permute(1, 0, 2).clone()
        assert v_his.shape[0] == num_particles
        v_his = v_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1)
        enabled = self.state['enabled'].clone().to(self.torch_device)[None].repeat(batch_size, 1)

        for t in range(num_steps):
            x_in = x.clone()
            pred = self.material(x, v, x_his, v_his, enabled)
            x, v = self.sim(self.statics, self.colliders, t, x, v, self.friction, pred)

        x_pred = x[0].clone()
        v_pred = v[0].clone()
        self.state['x_pred'] = x_pred
        self.state['v_pred'] = v_pred

        # Advance the keyframe and clear the consumed sub-steps.
        self.state['prev_key_pos'] = eef_xyz_key_next
        self.state['sub_pos'] = None
    def preprocess_x(self, p_x):
        # Viewer frame -> simulation frame. Only the viewer transform is
        # applied here; R / scale / global_translation were already baked
        # into the Gaussians by preprocess_gs.
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        p_x = (p_x - t_viewer) @ R_viewer
        return p_x

    def preprocess_gripper(self, grippers):
        # Rotation only: this is used for direction-like quantities
        # (e.g. the velocity commands in run_command), not positions.
        R_viewer = self.preprocess_metadata['R_viewer']
        grippers[:, :3] = grippers[:, :3] @ R_viewer
        return grippers

    def inverse_preprocess_x(self, p_x):
        # Simulation frame -> viewer frame.
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        p_x = p_x @ R_viewer.T + t_viewer
        return p_x

    def inverse_preprocess_gripper(self, grippers):
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        grippers[:, :3] = grippers[:, :3] @ R_viewer.T + t_viewer
        return grippers
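
    # The two transforms above are exact inverses of each other; a quick
    # sanity check (hypothetical):
    #   p = torch.randn(10, 3, device=self.device)
    #   assert torch.allclose(self.preprocess_x(self.inverse_preprocess_x(p)), p, atol=1e-5)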

    def rotate(self, params, rot_mat):
        # NOTE: the original body referenced undefined `pts` and `quats`; the
        # reconstruction below is an assumption: factor a possible uniform
        # scale out of rot_mat, rotate the Gaussian centers and orientations,
        # and leave the remaining attributes untouched.
        scale = np.linalg.norm(rot_mat, axis=1, keepdims=True)
        rot = torch.tensor(rot_mat / scale, dtype=torch.float32, device=self.device)
        pts = params['means3D'] @ rot.T * float(scale.mean())
        quats = mat2quat(rot @ quat2mat(torch.nn.functional.normalize(params['unnorm_rotations'], dim=-1)))

        params = {
            'means3D': pts,
            'rgb_colors': params['rgb_colors'],
            'log_scales': params['log_scales'],
            'unnorm_rotations': quats,
            'logit_opacities': params['logit_opacities'],
        }
        return params

    def preprocess_gs(self, params):
        if isinstance(params, dict):
            xyz = params['means3D']
            rgb = params['rgb_colors']
            quat = torch.nn.functional.normalize(params['unnorm_rotations'])
            opa = torch.sigmoid(params['logit_opacities'])
            scales = torch.exp(params['log_scales'])
        else:
            assert isinstance(params, tuple)
            xyz, rgb, quat, opa, scales = params

        quat = torch.nn.functional.normalize(quat, dim=-1)

        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']

        # Map the Gaussians into the simulation frame.
        mat = quat2mat(quat)
        mat = R @ mat
        xyz = xyz @ R.T
        xyz = xyz * scale
        xyz += global_translation
        quat = mat2quat(mat)
        scales = scales * scale

        # Then into the viewer frame.
        xyz = xyz @ R_viewer.T
        quat = mat2quat(R_viewer @ quat2mat(quat))

        t_viewer = -xyz.mean(dim=0)
        t_viewer[2] = 0  # recenter in the plane only; keep the vertical offset
        xyz += t_viewer
        print('Overwriting t_viewer to be the planar mean of the object')
        self.preprocess_metadata['t_viewer'] = t_viewer

        if isinstance(params, dict):
            params['means3D'] = xyz
            params['rgb_colors'] = rgb
            params['unnorm_rotations'] = quat
            params['logit_opacities'] = opa
            params['log_scales'] = torch.log(scales)
        else:
            params = xyz, rgb, quat, opa, scales

        return params

    def preprocess_bg_gs(self):
        t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params
        g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params

        # Identify the gripper tip (a thin slab below the highest point)
        # before any transform, so it can be re-centered afterwards.
        g_pts_tip_z = g_pts[:, 2].max()
        g_pts_tip_mask = (g_pts[:, 2] > g_pts_tip_z - 0.04) & (g_pts[:, 2] < g_pts_tip_z)

        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']

        # Table: simulation frame, then viewer frame.
        t_mat = quat2mat(t_quats)
        t_mat = R @ t_mat
        t_pts = t_pts @ R.T
        t_pts = t_pts * scale
        t_pts += global_translation
        t_quats = mat2quat(t_mat)
        t_scales = t_scales * scale

        t_pts = t_pts @ R_viewer.T
        t_quats = mat2quat(R_viewer @ quat2mat(t_quats))
        t_pts += t_viewer

        # Gripper: same two transforms.
        g_mat = quat2mat(g_quats)
        g_mat = R @ g_mat
        g_pts = g_pts @ R.T
        g_pts = g_pts * scale
        g_pts += global_translation
        g_quats = mat2quat(g_mat)
        g_scales = g_scales * scale

        g_pts = g_pts @ R_viewer.T
        g_quats = mat2quat(R_viewer @ quat2mat(g_quats))
        g_pts += t_viewer

        # Re-center the gripper so its tip sits at the origin in x/y.
        g_pts_tip = g_pts[g_pts_tip_mask]
        g_pts_tip_mean_xy = g_pts_tip[:, :2].mean(dim=0)
        g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.23]).to(torch.float32).to(self.device)
        g_pts = g_pts + g_pts_translation

        self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
        self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities

    def update_rendervar(self, rendervar):
        # Warp the Gaussians from the current particle state to the predicted
        # one via skinning: the particles act as bones, the Gaussians follow.
        p_x = self.state['x']
        p_x_viewer = self.inverse_preprocess_x(p_x)

        p_x_pred = self.state['x_pred']
        p_x_pred_viewer = self.inverse_preprocess_x(p_x_pred)

        xyz = rendervar['means3D']
        rgb = rendervar['colors_precomp']
        quat = rendervar['rotations']
        opa = rendervar['opacities']
        scales = rendervar['scales']

        relations = self.knn_relations(p_x_viewer)
        weights = self.knn_weights_brute(p_x_viewer, xyz)
        xyz, quat, _ = interpolate_motions(
            bones=p_x_viewer,
            motions=p_x_pred_viewer - p_x_viewer,
            relations=relations,
            weights=weights,
            xyz=xyz,
            quat=quat,
        )

        quat = torch.nn.functional.normalize(quat, dim=-1)

        rendervar = {
            'means3D': xyz,
            'colors_precomp': rgb,
            'rotations': quat,
            'opacities': opa,
            'scales': scales,
            'means2D': torch.zeros_like(xyz),
        }

        if self.with_bg:
            # Append the static table and, optionally, the gripper Gaussians.
            t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params

            xyz = torch.cat([xyz, t_pts], dim=0)
            rgb = torch.cat([rgb, t_colors], dim=0)
            quat = torch.cat([quat, t_quats], dim=0)
            opa = torch.cat([opa, t_opacities], dim=0)
            scales = torch.cat([scales, t_scales], dim=0)

            if self.render_gripper:
                g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params

                # Place the gripper at the latest keyframe position (viewer frame).
                g_pts = g_pts + self.inverse_preprocess_gripper(self.state['prev_key_pos'][None].clone())[0]

                xyz = torch.cat([xyz, g_pts], dim=0)
                rgb = torch.cat([rgb, g_colors], dim=0)
                quat = torch.cat([quat, g_quats], dim=0)
                opa = torch.cat([opa, g_opacities], dim=0)
                scales = torch.cat([scales, g_scales], dim=0)

            quat = torch.nn.functional.normalize(quat, dim=-1)

            rendervar_full = {
                'means3D': xyz,
                'colors_precomp': rgb,
                'rotations': quat,
                'opacities': opa,
                'scales': scales,
                'means2D': torch.zeros_like(xyz),
            }

        else:
            rendervar_full = rendervar

        return rendervar, rendervar_full

    def reset_state(self, params, visualize_image=False, init=False):
        xyz_0 = params['means3D']
        rgb_0 = params['rgb_colors']
        quat_0 = torch.nn.functional.normalize(params['unnorm_rotations'])
        opa_0 = torch.sigmoid(params['logit_opacities'])
        scales_0 = torch.exp(params['log_scales'])

        rendervar_init = {
            'means3D': xyz_0,
            'colors_precomp': rgb_0,
            'rotations': quat_0,
            'opacities': opa_0,
            'scales': scales_0,
            'means2D': torch.zeros_like(xyz_0),
        }

        # Build a fixed orbit camera from (distance, elevation, azimuth)
        # around a point slightly above the origin.
        w = self.width
        h = self.height
        center = (0, 0, 0.1)
        distance = 0.7
        elevation = 20
        azimuth = 180.0
        target = np.array(center)
        theta = 90 + azimuth
        z = distance * math.sin(math.radians(elevation))
        y = math.cos(math.radians(theta)) * distance * math.cos(math.radians(elevation))
        x = math.sin(math.radians(theta)) * distance * math.cos(math.radians(elevation))
        origin = target + np.array([x, y, z])

        look_at = target - origin
        look_at /= np.linalg.norm(look_at)
        up = np.array([0.0, 0.0, 1.0])
        right = np.cross(look_at, up)
        right /= np.linalg.norm(right)
        up = np.cross(right, look_at)
        w2c = np.eye(4)
        w2c[:3, 0] = right
        w2c[:3, 1] = -up
        w2c[:3, 2] = look_at
        w2c[:3, 3] = origin
        w2c = np.linalg.inv(w2c)

        # Simple pinhole intrinsics: focal length w/2, centered principal point.
        k = np.array(
            [[w / 2 * 1.0, 0., w / 2],
             [0., w / 2 * 1.0, h / 2],
             [0., 0., 1.]],
        )
        self.update_camera(k, w2c, w, h)

        # Downsample the Gaussian centers to the simulator's particle count.
        n_particles = self.cfg.sim.n_particles
        downsample_indices = fps(xyz_0, torch.ones_like(xyz_0[:, 0]).to(torch.bool), n_particles, self.torch_device)
        p_x_viewer = xyz_0[downsample_indices]
        p_x = self.preprocess_x(p_x_viewer)

        self.state['x'] = p_x
        self.state['v'] = torch.zeros_like(p_x)
        self.state['x_his'] = p_x[None].repeat(self.cfg.sim.n_history, 1, 1)
        self.state['v_his'] = torch.zeros_like(p_x[None].repeat(self.cfg.sim.n_history, 1, 1))
        self.state['x_pred'] = p_x
        self.state['v_pred'] = torch.zeros_like(p_x)

        rendervar_init, rendervar_init_full = self.update_rendervar(rendervar_init)
        im, depth = self.render(rendervar_init_full, 0, bg=[0.0, 0.0, 0.0])
        im_vis = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8)

        return rendervar_init

    def _init_devices(self):
        # Warp/torch device setup; runs inside every @spaces.GPU entry point
        # because ZeroGPU only attaches the GPU for the decorated call.
        wp.init()
        gpus = [int(gpu) for gpu in self.cfg.gpus]
        wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
        torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
        device_count = len(torch_devices)
        assert device_count == 1
        self.wp_device = wp_devices[0]
        self.torch_device = torch_devices[0]

    @spaces.GPU
    def reset(self):
        self._init_devices()

        in_dir = root / 'log/gs/ckpts/rope_scene_1'
        batch_size = 1
        num_steps = 1
        num_particles = self.cfg.sim.n_particles
        self.load_scaniverse(in_dir)
        self.init_model(batch_size, num_steps, num_particles, ckpt_path=None)

        params = self.preprocess_gs(self.params)
        if self.with_bg:
            self.preprocess_bg_gs()
        rendervar = self.reset_state(params, visualize_image=False, init=True)
        rendervar, rendervar_full = self.update_rendervar(rendervar)
        self.rendervar = rendervar

        im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
        im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()

        # Ensure the output directories exist; cv2.imwrite fails silently otherwise.
        os.makedirs(root / 'log/temp_init', exist_ok=True)
        os.makedirs(root / 'log/gs/temp', exist_ok=True)
        cv2.imwrite(str(root / 'log/temp_init/0000.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR))

        make_video(root / 'log/temp_init', root / 'log/gs/temp/form_video_init.mp4', '%04d.png', 1)

        save_to_splat(
            rendervar_full['means3D'].cpu().numpy(),
            rendervar_full['colors_precomp'].cpu().numpy(),
            rendervar_full['scales'].cpu().numpy(),
            rendervar_full['rotations'].cpu().numpy(),
            rendervar_full['opacities'].cpu().numpy(),
            root / 'log/gs/temp/gs_pred.splat',
            rot_rev=True,
        )

        form_video = gr.Video(
            label='Predicted video',
            value=root / 'log/gs/temp/form_video.mp4',
            format='mp4',
            width=self.width,
            height=self.height,
        )
        form_3dgs_pred = gr.Model3D(
            label='Predicted Gaussian Splats',
            height=self.height,
            value=root / 'log/gs/temp/gs_pred.splat',
            clear_color=[0, 0, 0, 0],
        )

        return form_video, form_3dgs_pred

    def run_command(self, unit_command):
        # Roll the simulation forward 15 steps under a constant unit command.
        os.makedirs(root / 'log/temp', exist_ok=True)
        os.makedirs(root / 'log/gs/temp', exist_ok=True)
        os.system('rm -rf ' + str(root / 'log/temp/*'))

        for i in range(15):
            dt = 0.1
            command = torch.tensor([unit_command]).to(self.device).to(torch.float32)
            command = self.preprocess_gripper(command)

            if self.verbose:
                print('command:', command.cpu().numpy().tolist())

            # sub_pos is always consumed by step() below, so it must be empty here.
            assert self.state['sub_pos'] is None

            if self.state['sub_pos'] is None:
                eef_xyz_latest = self.state['prev_key_pos']
            else:
                eef_xyz_latest = self.state['sub_pos'][-1]

            # Integrate the command into a new sub-step end-effector position.
            eef_xyz_updated = eef_xyz_latest + command * dt * 0.01

            if self.state['sub_pos'] is None:
                self.state['sub_pos'] = eef_xyz_updated[None]
            else:
                self.state['sub_pos'] = torch.cat([self.state['sub_pos'], eef_xyz_updated[None]], dim=0)

            self.step()
            rendervar, rendervar_full = self.update_rendervar(self.rendervar)
            self.rendervar = rendervar
            im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
            im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()

            cv2.imwrite(str(root / f'log/temp/{i:04}.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR))

            # Zero out velocities and collapse the history between steps so
            # each command acts on a quasi-static state.
            self.state['v'] *= 0.0
            self.state['x'] = self.state['x_pred'].clone()
            self.state['x_his'] = self.state['x'][None].repeat(self.cfg.sim.n_history, 1, 1)
            self.state['v_his'] *= 0.0
            self.state['v_pred'] *= 0.0

        make_video(root / 'log/temp', root / 'log/gs/temp/form_video.mp4', '%04d.png', 5)

        form_video = gr.Video(
            label='Predicted video',
            value=root / 'log/gs/temp/form_video.mp4',
            format='mp4',
            width=self.width,
            height=self.height,
        )

        im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
        im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()

        save_to_splat(
            rendervar_full['means3D'].cpu().numpy(),
            rendervar_full['colors_precomp'].cpu().numpy(),
            rendervar_full['scales'].cpu().numpy(),
            rendervar_full['rotations'].cpu().numpy(),
            rendervar_full['opacities'].cpu().numpy(),
            root / 'log/gs/temp/gs_pred.splat',
            rot_rev=True,
        )
        form_3dgs_pred = gr.Model3D(
            label='Predicted Gaussian Splats',
            height=self.height,
            value=root / 'log/gs/temp/gs_pred.splat',
            clear_color=[0, 0, 0, 0],
        )
        return form_video, form_3dgs_pred

    @spaces.GPU
    def on_click_run_xplus(self):
        self._init_devices()
        return self.run_command([5.0, 0, 0])

    @spaces.GPU
    def on_click_run_xminus(self):
        self._init_devices()
        return self.run_command([-5.0, 0, 0])

    @spaces.GPU
    def on_click_run_yplus(self):
        self._init_devices()
        return self.run_command([0, 5.0, 0])

    @spaces.GPU
    def on_click_run_yminus(self):
        self._init_devices()
        return self.run_command([0, -5.0, 0])

    @spaces.GPU
    def on_click_run_zplus(self):
        self._init_devices()
        return self.run_command([0, 0, 5.0])

    @spaces.GPU
    def on_click_run_zminus(self):
        self._init_devices()
        return self.run_command([0, 0, -5.0])

    def launch(self, share=False):
        with gr.Blocks() as app:

            with gr.Row():
                gr.Markdown("# Particle-Grid Neural Dynamics for Learning Deformable Object Models from RGB-D Videos")

            with gr.Row():
                gr.Markdown('### Project page: [https://kywind.github.io/pgnd](https://kywind.github.io/pgnd)')

            with gr.Row():
                with gr.Column(scale=2):
                    form_video = gr.Video(
                        label='Predicted video',
                        value=None,
                        format='mp4',
                        width=self.width,
                        height=self.height,
                    )

                with gr.Column(scale=2):
                    form_3dgs_pred = gr.Model3D(
                        label='Predicted Gaussians',
                        height=self.height,
                        value=None,
                        clear_color=[0, 0, 0, 0],
                    )

            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Row():
                        run_reset = gr.Button("Reset")

                    with gr.Row():
                        with gr.Column():
                            run_xminus = gr.Button("x-")
                        with gr.Column():
                            run_xplus = gr.Button("x+")

                    with gr.Row():
                        with gr.Column():
                            run_yminus = gr.Button("y-")
                        with gr.Column():
                            run_yplus = gr.Button("y+")

                    with gr.Row():
                        with gr.Column():
                            run_zminus = gr.Button("z-")
                        with gr.Column():
                            run_zplus = gr.Button("z+")

                with gr.Column(scale=2):
                    _ = gr.Button(visible=False)  # invisible spacer column

            run_reset.click(self.reset,
                            inputs=[],
                            outputs=[form_video, form_3dgs_pred])

            run_xplus.click(self.on_click_run_xplus,
                            inputs=[],
                            outputs=[form_video, form_3dgs_pred])

            run_xminus.click(self.on_click_run_xminus,
                             inputs=[],
                             outputs=[form_video, form_3dgs_pred])

            run_yplus.click(self.on_click_run_yplus,
                            inputs=[],
                            outputs=[form_video, form_3dgs_pred])

            run_yminus.click(self.on_click_run_yminus,
                             inputs=[],
                             outputs=[form_video, form_3dgs_pred])

            run_zplus.click(self.on_click_run_zplus,
                            inputs=[],
                            outputs=[form_video, form_3dgs_pred])

            run_zminus.click(self.on_click_run_zminus,
                             inputs=[],
                             outputs=[form_video, form_3dgs_pred])

        app.launch(share=share)


if __name__ == '__main__':
    visualizer = DynamicsVisualizer()
    visualizer.launch(share=True)