import gradio as gr
import importlib
import sys
import site
from PIL import Image
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm, trange
import random
import math
import hydra
import numpy as np
import torch
import torch.nn as nn

print(torch.__version__)
print(torch.version.cuda)
# import torch.backends.cudnn

import warp as wp
import glob
from torch.utils.data import DataLoader
import os
import subprocess
import time
import cv2
import copy
import kornia
import yaml
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors

import spaces
from spaces import zero
zero.startup()


def install_cuda_toolkit():
    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
    subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
    subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
    subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
    os.environ["CUDA_HOME"] = "/usr/local/cuda"
    os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
    os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
        os.environ["CUDA_HOME"],
        "" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
    )
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"

install_cuda_toolkit()

# Build the rasterizer against the freshly installed toolkit, then make the
# new package importable in the running process.
gs_path = Path(__file__).parent / "src/third-party/diff-gaussian-rasterization-w-depth"
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", str(gs_path)])
site.main()  # re-processes every *.pth in site-packages
importlib.invalidate_caches()
diff_gaussian_rasterization = importlib.import_module("diff_gaussian_rasterization")

os.system('pip install dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html')
# os.system('conda install conda-forge::ffmpeg')

sys.path.insert(0, str(Path(__file__).parent / "src"))
sys.path.append(str(Path(__file__).parent / "src" / "experiments"))

from pgnd.sim import Friction, CacheDiffSimWithFrictionBatch, StaticsBatch, CollidersBatch
from pgnd.material import PGNDModel
from pgnd.utils import Logger, get_root, mkdir
from pgnd.ffmpeg import make_video

from real_world.utils.render_utils import interpolate_motions
from real_world.gs.helpers import setup_camera
from real_world.gs.convert import save_to_splat, read_splat

from diff_gaussian_rasterization import GaussianRasterizer
from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera

root = Path(__file__).parent / "src" / "experiments"


def quat2mat(quat):
    return kornia.geometry.conversions.quaternion_to_rotation_matrix(quat)


def mat2quat(mat):
    return kornia.geometry.conversions.rotation_matrix_to_quaternion(mat)


def fps(x, enabled, n, device, random_start=False):
    # Farthest-point sampling of n indices from the enabled subset of x.
    from dgl.geometry import farthest_point_sampler
    # `enabled` must be a contiguous prefix of True values (at most one True->False edge)
    assert torch.diff(enabled * 1.0).sum() in [0.0, -1.0]
    start_idx = random.randint(0, enabled.sum() - 1) if random_start else 0
    fps_idx = farthest_point_sampler(x[enabled][None], n, start_idx=start_idx)[0]
    fps_idx = fps_idx.to(x.device)
    return fps_idx
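
# `Rt_to_w2c` is referenced by `set_camera` below but was not defined in this
# file. A minimal constructor is sketched here, under the assumption that R and
# t are already the world-to-camera rotation (3x3) and translation (3,):
def Rt_to_w2c(R, t):
    w2c = np.eye(4)
    w2c[:3, :3] = np.asarray(R)
    w2c[:3, 3] = np.asarray(t).reshape(3)
    return w2c


# Usage sketch for `fps` (illustrative only, not executed at import time):
#     x = torch.rand(5000, 3, device='cuda')
#     enabled = torch.ones(5000, dtype=torch.bool, device='cuda')
#     idx = fps(x, enabled, 1000, 'cuda')   # (1000,) indices into x[enabled]
#     x_down = x[enabled][idx]              # (1000, 3) downsampled points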

class DynamicsVisualizer:

    def __init__(self):
        self.width = 640
        self.height = 480

        # task -> [log dir, run name, checkpoint iteration, [start_episode, end_episode]]
        best_models = {
            'cloth': ['cloth', 'train', 100000, [610, 650]],
            'rope': ['rope', 'train', 100000, [651, 691]],
            'paperbag': ['paperbag', 'train', 100000, [200, 220]],
            'sloth': ['sloth', 'train', 100000, [113, 133]],
            'box': ['box', 'train', 100000, [306, 323]],
            'bread': ['bread', 'train', 100000, [143, 163]],
        }
        task_name = 'rope'

        with open(root / f'log/{best_models[task_name][0]}/{best_models[task_name][1]}/hydra.yaml', 'r') as f:
            config = yaml.load(f, Loader=yaml.CLoader)
        cfg = OmegaConf.create(config)
        cfg.iteration = best_models[task_name][2]
        cfg.start_episode = best_models[task_name][3][0]
        cfg.end_episode = best_models[task_name][3][1]
        cfg.sim.num_steps = 1000
        cfg.sim.gripper_forcing = False
        cfg.sim.uniform = True
        cfg.sim.use_pv = False

        device = torch.device('cuda')
        self.cfg = cfg
        self.device = device

        self.k_rel = 8   # knn for relations
        self.k_wgt = 16  # knn for weights
        self.with_bg = True
        self.render_gripper = True
        self.verbose = False
        self.dt_base = cfg.sim.dt
        self.high_freq_pred = True

        seed = cfg.seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # torch.autograd.set_detect_anomaly(True)
        # torch.backends.cudnn.benchmark = True

        self.clear()

    def clear(self, clear_params=True):
        self.metadata = {}
        self.config = {}
        if clear_params:
            self.params = None
        self.state = {
            # object
            'x': None,
            'v': None,
            'x_his': None,
            'v_his': None,
            'x_pred': None,
            'v_pred': None,
            'clip_bound': None,
            'enabled': None,
            # robot
            'prev_key_pos': None,
            'prev_key_pos_timestamp': None,
            'sub_pos': None,  # fills in between key positions
            'sub_pos_timestamps': None,
            'gripper_radius': None,
        }
        self.preprocess_metadata = None
        self.table_params = None
        self.gripper_params = None
        self.sim = None
        self.statics = None
        self.colliders = None
        self.material = None
        self.friction = None
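
    # Coordinate-frame conventions used throughout (inferred from the frame
    # comments in preprocess_x / preprocess_gs below):
    #   - data frame:   raw splat / end-effector coordinates as loaded from disk
    #   - model frame:  simulation coordinates; data is rotated by R, scaled, and
    #                   translated so the object sits inside the unit grid
    #   - viewer frame: what the camera and splat viewer see; model coordinates
    #                   rotated by R_viewer and shifted by t_viewer
    # preprocess_* maps viewer -> model; inverse_preprocess_* maps model -> viewer.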

    def load_scaniverse(self, data_path):
        ### load splat params
        params_obj = read_splat(data_path / 'object.splat')
        params_table = read_splat(data_path / 'table.splat')
        params_robot = read_splat(data_path / 'gripper.splat')

        pts, colors, scales, quats, opacities = params_obj
        self.params = {
            'means3D': torch.from_numpy(pts).to(torch.float32).to(self.device),
            'rgb_colors': torch.from_numpy(colors).to(torch.float32).to(self.device),
            'log_scales': torch.log(torch.from_numpy(scales).to(torch.float32).to(self.device)),
            'unnorm_rotations': torch.from_numpy(quats).to(torch.float32).to(self.device),
            'logit_opacities': torch.logit(torch.from_numpy(opacities).to(torch.float32).to(self.device)),
        }

        t_pts, t_colors, t_scales, t_quats, t_opacities = params_table
        t_pts = torch.tensor(t_pts).to(torch.float32).to(self.device)
        t_colors = torch.tensor(t_colors).to(torch.float32).to(self.device)
        t_scales = torch.tensor(t_scales).to(torch.float32).to(self.device)
        t_quats = torch.tensor(t_quats).to(torch.float32).to(self.device)
        t_opacities = torch.tensor(t_opacities).to(torch.float32).to(self.device)

        g_pts, g_colors, g_scales, g_quats, g_opacities = params_robot
        g_pts = torch.tensor(g_pts).to(torch.float32).to(self.device)
        g_colors = torch.tensor(g_colors).to(torch.float32).to(self.device)
        g_scales = torch.tensor(g_scales).to(torch.float32).to(self.device)
        g_quats = torch.tensor(g_quats).to(torch.float32).to(self.device)
        g_opacities = torch.tensor(g_opacities).to(torch.float32).to(self.device)

        self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities      # data frame
        self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities    # data frame

        n_particles = self.cfg.sim.n_particles
        self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
        self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)

        ### load preprocess metadata
        cfg = self.cfg
        dx = cfg.sim.num_grids[-1]
        p_x = torch.tensor(pts).to(torch.float32).to(self.device)
        R = torch.tensor(
            [[1, 0, 0],
             [0, 0, -1],
             [0, 1, 0]]
        ).to(p_x.device).to(p_x.dtype)
        p_x_rotated = p_x @ R.T
        scale = 1.0
        p_x_rotated_scaled = p_x_rotated * scale
        global_translation = torch.tensor([
            0.5 - p_x_rotated_scaled[:, 0].mean(),
            dx * (cfg.model.clip_bound + 0.5) - p_x_rotated_scaled[:, 1].min(),
            0.5 - p_x_rotated_scaled[:, 2].mean(),
        ], dtype=p_x_rotated_scaled.dtype, device=p_x_rotated_scaled.device)
        R_viewer = torch.tensor(
            [[1, 0, 0],
             [0, 0, -1],
             [0, 1, 0]]
        ).to(p_x.device).to(p_x.dtype)
        t_viewer = torch.tensor([0, 0, 0]).to(p_x.device).to(p_x.dtype)
        self.preprocess_metadata = {
            'R': R,
            'R_viewer': R_viewer,
            't_viewer': t_viewer,
            'scale': scale,
            'global_translation': global_translation,
        }

        ### load eef
        grippers = np.loadtxt(data_path / 'eef_xyz.txt')[None]
        assert grippers.shape == (1, 3)
        if grippers is not None:
            grippers = torch.tensor(grippers).to(self.device).to(torch.float32)

            # transform: data frame to model frame
            R = self.preprocess_metadata['R']
            scale = self.preprocess_metadata['scale']
            global_translation = self.preprocess_metadata['global_translation']
            grippers[:, :3] = grippers[:, :3] @ R.T
            grippers[:, :3] = grippers[:, :3] * scale
            grippers[:, :3] += global_translation

            assert grippers.shape[0] == 1
            self.state['prev_key_pos'] = grippers[:, :3]  # (1, 3)
            # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32)
            self.state['gripper_radius'] = cfg.model.gripper_radius
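
    # The data -> model transform above is, in one line,
    #     x_model = scale * (x_data @ R.T) + global_translation
    # with R a 90-degree rotation about the x-axis, and the translation chosen so
    # the object is centered at 0.5 in x and z and its lowest point rests at
    # y = dx * (clip_bound + 0.5), i.e. just inside the simulation grid's clip bound.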

    def load_params(self, params_path, remove_low_opa=True, remove_black=False):
        pts, colors, scales, quats, opacities = read_splat(params_path)
        if remove_low_opa:
            low_opa_idx = opacities[:, 0] < 0.1
            pts = pts[~low_opa_idx]
            colors = colors[~low_opa_idx]
            quats = quats[~low_opa_idx]
            opacities = opacities[~low_opa_idx]
            scales = scales[~low_opa_idx]
        if remove_black:
            low_color_idx = colors.sum(axis=-1) < 0.5
            pts = pts[~low_color_idx]
            colors = colors[~low_color_idx]
            quats = quats[~low_color_idx]
            opacities = opacities[~low_color_idx]
            scales = scales[~low_color_idx]
        self.params = {
            'means3D': torch.from_numpy(pts).to(torch.float32).to(self.device),
            'rgb_colors': torch.from_numpy(colors).to(torch.float32).to(self.device),
            'log_scales': torch.log(torch.from_numpy(scales).to(torch.float32).to(self.device)),
            'unnorm_rotations': torch.from_numpy(quats).to(torch.float32).to(self.device),
            'logit_opacities': torch.logit(torch.from_numpy(opacities).to(torch.float32).to(self.device)),
        }

        table_splat = root / 'log/gs/ckpts/table.splat'
        sphere_splat = root / 'log/gs/ckpts/sphere.splat'
        gripper_splat = root / 'log/gs/ckpts/gripper.splat'  # gripper_new.splat
        table_params = read_splat(table_splat)  # numpy

        ## add table and gripper
        # add table
        t_pts, t_colors, t_scales, t_quats, t_opacities = table_params
        t_pts = torch.tensor(t_pts).to(torch.float32).to(self.device)
        t_colors = torch.tensor(t_colors).to(torch.float32).to(self.device)
        t_scales = torch.tensor(t_scales).to(torch.float32).to(self.device)
        t_quats = torch.tensor(t_quats).to(torch.float32).to(self.device)
        t_opacities = torch.tensor(t_opacities).to(torch.float32).to(self.device)

        # add table pos
        t_pts = t_pts + torch.tensor([0, 0, 0.02]).to(torch.float32).to(self.device)

        # add gripper
        gripper_params = read_splat(gripper_splat)  # numpy
        g_pts, g_colors, g_scales, g_quats, g_opacities = gripper_params
        g_pts = torch.tensor(g_pts).to(torch.float32).to(self.device)
        g_colors = torch.tensor(g_colors).to(torch.float32).to(self.device)
        g_scales = torch.tensor(g_scales).to(torch.float32).to(self.device)
        g_quats = torch.tensor(g_quats).to(torch.float32).to(self.device)
        g_opacities = torch.tensor(g_opacities).to(torch.float32).to(self.device)
        # we do not translate the gripper here: that would center it in the
        # data frame but not in the viewer frame

        self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities      # data frame
        self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities    # data frame

        # load other info
        n_particles = self.cfg.sim.n_particles
        self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
        self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)

    def set_camera(self, w, h, intr, w2c=None, R=None, t=None, near=0.01, far=100.0):
        if w2c is None:
            assert R is not None and t is not None
            w2c = Rt_to_w2c(R, t)
        self.metadata = {
            'w': w,
            'h': h,
            'k': intr,
            'w2c': w2c,
        }
        self.config = {'near': near, 'far': far}

    def load_eef(self, grippers=None, eef_t=None):
        assert self.state['prev_key_pos'] is None
        if grippers is not None:
            grippers = torch.tensor(grippers).to(self.device).to(torch.float32)
            eef_t = torch.tensor(eef_t).to(self.device).to(torch.float32)
            grippers[:, :3] = grippers[:, :3] + eef_t

            # transform: data frame to model frame
            R = self.preprocess_metadata['R']
            scale = self.preprocess_metadata['scale']
            global_translation = self.preprocess_metadata['global_translation']
            grippers[:, :3] = grippers[:, :3] @ R.T
            grippers[:, :3] = grippers[:, :3] * scale
            grippers[:, :3] += global_translation

            assert grippers.shape[0] == 1
            self.state['prev_key_pos'] = grippers[:, :3]  # (1, 3)
            # self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32) + 0.001
            self.state['gripper_radius'] = self.cfg.model.gripper_radius

    def load_preprocess_metadata(self, p_x_orig):
        cfg = self.cfg
        dx = cfg.sim.num_grids[-1]
        p_x_orig = p_x_orig.to(self.device)
        R = torch.tensor(
            [[1, 0, 0],
             [0, 0, -1],
             [0, 1, 0]]
        ).to(p_x_orig.device).to(p_x_orig.dtype)
        p_x_orig_rotated = torch.einsum('nij,jk->nik', p_x_orig, R.T)
        scale = 1.0
        p_x_orig_rotated_scaled = p_x_orig_rotated * scale
        global_translation = torch.tensor([
            0.5 - p_x_orig_rotated_scaled[:, :, 0].mean(),
            dx * (cfg.model.clip_bound + 0.5) - p_x_orig_rotated_scaled[:, :, 1].min(),
            0.5 - p_x_orig_rotated_scaled[:, :, 2].mean(),
        ], dtype=p_x_orig_rotated_scaled.dtype, device=p_x_orig_rotated_scaled.device)
        R_viewer = torch.tensor(
            [[1, 0, 0],
             [0, 0, -1],
             [0, 1, 0]]
        ).to(p_x_orig.device).to(p_x_orig.dtype)
        t_viewer = torch.tensor([0, 0, 0]).to(p_x_orig.device).to(p_x_orig.dtype)
        self.preprocess_metadata = {
            'R': R,
            'R_viewer': R_viewer,
            't_viewer': t_viewer,
            'scale': scale,
            'global_translation': global_translation,
        }

    @torch.no_grad()
    def render(self, render_data, cam_id, bg=[0.7, 0.7, 0.7]):
        render_data = {k: v.to(self.device) for k, v in render_data.items()}
        w, h = self.metadata['w'], self.metadata['h']
        k, w2c = self.metadata['k'], self.metadata['w2c']
        cam = setup_camera(w, h, k, w2c, self.config['near'], self.config['far'], bg)
        im, _, depth = GaussianRasterizer(raster_settings=cam)(**render_data)
        return im, depth

    def knn_relations(self, bones):
        k = self.k_rel
        knn = NearestNeighbors(n_neighbors=k + 1, algorithm='kd_tree').fit(bones.detach().cpu().numpy())
        _, indices = knn.kneighbors(bones.detach().cpu().numpy())  # (N, k+1)
        indices = indices[:, 1:]  # exclude self -> (N, k)
        return indices
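
    # knn_weights_brute computes, for every Gaussian, normalized inverse-distance
    # weights over its k_wgt nearest particles ("bones"):
    #     w_ij = (1 / (d_ij + eps)) / sum_j' (1 / (d_ij' + eps))
    # scattered into a dense (n_pts, n_bones) matrix consumed by interpolate_motions.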

    def knn_weights_brute(self, bones, pts):
        k = self.k_wgt
        dist = torch.norm(pts[:, None] - bones, dim=-1)  # (n_pts, n_bones)
        _, indices = torch.topk(dist, k, dim=-1, largest=False)
        bones_selected = bones[indices]  # (N, k, 3)
        dist = torch.norm(bones_selected - pts[:, None], dim=-1)  # (N, k)
        weights = 1 / (dist + 1e-6)
        weights = weights / weights.sum(dim=-1, keepdim=True)  # (N, k)
        weights_all = torch.zeros((pts.shape[0], bones.shape[0]), device=pts.device)
        weights_all[torch.arange(pts.shape[0])[:, None], indices] = weights
        return weights_all

    def update_camera(self, k, w2c, w=None, h=None, near=0.01, far=100.0):
        self.metadata['k'] = k
        self.metadata['w2c'] = w2c
        if w is not None:
            self.metadata['w'] = w
        if h is not None:
            self.metadata['h'] = h
        self.config['near'] = near
        self.config['far'] = far

    def init_model(self, batch_size, num_steps, num_particles, ckpt_path=None):
        self.cfg.sim.num_steps = num_steps
        cfg = self.cfg

        sim = CacheDiffSimWithFrictionBatch(cfg, num_steps, batch_size, self.wp_device, requires_grad=True)

        statics = StaticsBatch()
        statics.init(shape=(batch_size, num_particles), device=self.wp_device)
        statics.update_clip_bound(self.state['clip_bound'])
        statics.update_enabled(self.state['enabled'][None])
        colliders = CollidersBatch()
        colliders.init(shape=(batch_size, cfg.sim.num_grippers), device=self.wp_device)

        self.sim = sim
        self.statics = statics
        self.colliders = colliders

        # load ckpt
        # NOTE: the ckpt_path argument is ignored; the demo hardcodes the rope checkpoint
        ckpt_path = root / 'log/rope/train/ckpt/100000.pt'
        ckpt = torch.load(ckpt_path, map_location=self.torch_device)

        material: nn.Module = PGNDModel(cfg)
        material.to(self.torch_device)
        material.load_state_dict(ckpt['material'])
        material.requires_grad_(False)
        material.eval()

        if 'friction' in ckpt:
            friction = ckpt['friction']['mu'].reshape(-1, 1)
        else:
            friction = torch.tensor(cfg.model.friction.value, device=self.torch_device).reshape(-1, 1)

        self.material = material
        self.friction = friction

    def reload_model(self, num_steps):
        # only change num_steps
        self.cfg.sim.num_steps = num_steps
        sim = CacheDiffSimWithFrictionBatch(self.cfg, num_steps, 1, self.wp_device, requires_grad=True)
        self.sim = sim
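
    # Collider state layout assumed by step() below: each gripper is packed into
    # a 15-vector of [0:3] position, [3:6] linear velocity, [6:10] quaternion,
    # [10:13] angular velocity, [13] collision radius, [14] gripper open/close value.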

    @torch.no_grad()
    def step(self):
        cfg = self.cfg
        batch_size = 1
        num_steps = 1
        num_particles = cfg.sim.n_particles

        # update state by previous prediction
        self.state['x_his'] = torch.cat([self.state['x_his'][1:], self.state['x'][None]], dim=0)
        self.state['v_his'] = torch.cat([self.state['v_his'][1:], self.state['v'][None]], dim=0)
        self.state['x'] = self.state['x_pred'].clone()
        self.state['v'] = self.state['v_pred'].clone()

        eef_xyz_key = self.state['prev_key_pos']  # (1, 3), model frame
        eef_xyz_sub = self.state['sub_pos']       # (T, 1, 3), model frame
        if eef_xyz_sub is None:
            return

        # eef_xyz_key_timestamp = self.state['prev_key_pos_timestamp']
        # eef_xyz_sub_timestamps = self.state['sub_pos_timestamps']
        # assert eef_xyz_key_timestamp.item() > 0
        # delta_t = (eef_xyz_sub_timestamps[-1] - eef_xyz_key_timestamp).item()
        # if (not self.high_freq_pred) and delta_t < self.dt_base * 0.9:
        #     return
        # cfg.sim.dt = delta_t

        eef_xyz_key_next = eef_xyz_sub[-1]  # (1, 3), model frame
        eef_v = (eef_xyz_key_next - eef_xyz_key) / cfg.sim.dt

        if self.verbose:
            print('delta_t:', np.round(cfg.sim.dt, 4))
            print('eef_xyz_key_next:', eef_xyz_key_next.cpu().numpy().tolist())
            print('eef_xyz_key:', eef_xyz_key.cpu().numpy().tolist())
            print('v:', eef_v.cpu().numpy().tolist())

        # load model, sim, statics, colliders
        self.reload_model(num_steps)

        # initialize colliders
        if cfg.sim.num_grippers > 0:
            grippers = torch.zeros((batch_size, cfg.sim.num_grippers, 15), device=self.torch_device)
            eef_quat = torch.tensor(
                [1, 0, 0, 0], dtype=torch.float32, device=self.torch_device
            ).repeat(batch_size, cfg.sim.num_grippers, 1)  # (B, G, 4)
            eef_quat_vel = torch.zeros((batch_size, cfg.sim.num_grippers, 3), dtype=torch.float32, device=self.torch_device)
            eef_gripper = torch.zeros((batch_size, cfg.sim.num_grippers), dtype=torch.float32, device=self.torch_device)
            grippers[:, :, :3] = eef_xyz_key
            grippers[:, :, 3:6] = eef_v
            grippers[:, :, 6:10] = eef_quat
            grippers[:, :, 10:13] = eef_quat_vel
            grippers[:, :, 13] = cfg.model.gripper_radius
            grippers[:, :, 14] = eef_gripper
            self.colliders.initialize_grippers(grippers)

        x = self.state['x'].clone()[None].repeat(batch_size, 1, 1)
        v = self.state['v'].clone()[None].repeat(batch_size, 1, 1)
        x_his = self.state['x_his'].permute(1, 0, 2).clone()
        assert x_his.shape[0] == num_particles
        x_his = x_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1)
        v_his = self.state['v_his'].permute(1, 0, 2).clone()
        assert v_his.shape[0] == num_particles
        v_his = v_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1)
        enabled = self.state['enabled'].clone().to(self.torch_device)[None].repeat(batch_size, 1)

        for t in range(num_steps):
            pred = self.material(x, v, x_his, v_his, enabled)
            # x_his = torch.cat([x_his.reshape(batch_size, num_particles, -1, 3)[:, :, 1:], x[:, :, None].detach()], dim=2)
            # v_his = torch.cat([v_his.reshape(batch_size, num_particles, -1, 3)[:, :, 1:], v[:, :, None].detach()], dim=2)
            # x_his = x_his.reshape(batch_size, num_particles, -1)
            # v_his = v_his.reshape(batch_size, num_particles, -1)
            x, v = self.sim(self.statics, self.colliders, t, x, v, self.friction, pred)

        # calculate new x_pred, v_pred, eef_xyz_key and eef_xyz_sub
        x_pred = x[0].clone()
        v_pred = v[0].clone()
        self.state['x_pred'] = x_pred
        self.state['v_pred'] = v_pred
        # self.state['x_his'] = x_his[0].reshape(num_particles, self.cfg.sim.n_history, 3).permute(1, 0, 2).clone()
        # self.state['v_his'] = v_his[0].reshape(num_particles, self.cfg.sim.n_history, 3).permute(1, 0, 2).clone()
        self.state['prev_key_pos'] = eef_xyz_key_next
        # self.state['prev_key_pos_timestamp'] = eef_xyz_sub_timestamps[-1]
        self.state['sub_pos'] = None
        # self.state['sub_pos_timestamps'] = None

    def preprocess_x(self, p_x):
        # viewer frame to model frame (not data frame)
        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']

        # viewer frame to model frame
        p_x = (p_x - t_viewer) @ R_viewer

        # model frame to data frame
        # p_x -= global_translation
        # p_x = p_x / scale
        # p_x = p_x @ torch.linalg.inv(R).T
        return p_x

    def preprocess_gripper(self, grippers):
        # viewer frame to model frame (not data frame)
        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']

        # viewer frame to model frame; rotation only, since this is applied to
        # velocity commands, which have no translation component
        grippers[:, :3] = grippers[:, :3] @ R_viewer
        return grippers

    def inverse_preprocess_x(self, p_x):
        # model frame (not data frame) to viewer frame
        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']

        # model frame to viewer frame
        p_x = p_x @ R_viewer.T + t_viewer
        return p_x

    def inverse_preprocess_gripper(self, grippers):
        # model frame (not data frame) to viewer frame
        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']

        # model frame to viewer frame
        grippers[:, :3] = grippers[:, :3] @ R_viewer.T + t_viewer
        return grippers

    def rotate(self, params, rot_mat):
        # NOTE: the original body referenced `pts` and `quats` without defining
        # them (this method is unused by the demo). A minimal completion is
        # sketched here: strip any uniform scale from rot_mat, then rotate the
        # means and compose the rotations.
        rot_mat = torch.as_tensor(rot_mat, dtype=torch.float32, device=self.device)
        scale = torch.linalg.norm(rot_mat, dim=1, keepdim=True)
        R = rot_mat / scale
        pts = params['means3D'] @ R.T
        quats = mat2quat(R @ quat2mat(torch.nn.functional.normalize(params['unnorm_rotations'], dim=-1)))
        params = {
            'means3D': pts,
            'rgb_colors': params['rgb_colors'],
            'log_scales': params['log_scales'],
            'unnorm_rotations': quats,
            'logit_opacities': params['logit_opacities'],
        }
        return params

    def preprocess_gs(self, params):
        if isinstance(params, dict):
            xyz = params['means3D']
            rgb = params['rgb_colors']
            quat = torch.nn.functional.normalize(params['unnorm_rotations'])
            opa = torch.sigmoid(params['logit_opacities'])
            scales = torch.exp(params['log_scales'])
        else:
            assert isinstance(params, tuple)
            xyz, rgb, quat, opa, scales = params
            quat = torch.nn.functional.normalize(quat, dim=-1)

        # transform: data frame to model frame
        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']

        mat = quat2mat(quat)
        mat = R @ mat
        xyz = xyz @ R.T
        xyz = xyz * scale
        xyz += global_translation
        quat = mat2quat(mat)
        scales = scales * scale

        # viewer-specific transform (flip y and z)
        # model frame to viewer frame
        xyz = xyz @ R_viewer.T
        quat = mat2quat(R_viewer @ quat2mat(quat))

        t_viewer = -xyz.mean(dim=0)
        t_viewer[2] = 0
        xyz += t_viewer
        print('Overwriting t_viewer to be the planar mean of the object')
        self.preprocess_metadata['t_viewer'] = t_viewer

        if isinstance(params, dict):
            params['means3D'] = xyz
            params['rgb_colors'] = rgb
            params['unnorm_rotations'] = quat
            # NOTE: the original stored the sigmoided opacities back into
            # 'logit_opacities', which reset_state would sigmoid a second time;
            # invert the sigmoid here so the round trip is lossless.
            params['logit_opacities'] = torch.logit(opa)
            params['log_scales'] = torch.log(scales)
        else:
            params = xyz, rgb, quat, opa, scales
        return params

    def preprocess_bg_gs(self):
        t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params
        g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params

        # identify tip first
        g_pts_tip_z = g_pts[:, 2].max()
        g_pts_tip_mask = (g_pts[:, 2] > g_pts_tip_z - 0.04) & (g_pts[:, 2] < g_pts_tip_z)

        # uses the t_viewer just overwritten by preprocess_gs, so preprocess_gs
        # must run first (as in reset)
        R = self.preprocess_metadata['R']
        R_viewer = self.preprocess_metadata['R_viewer']
        t_viewer = self.preprocess_metadata['t_viewer']
        scale = self.preprocess_metadata['scale']
        global_translation = self.preprocess_metadata['global_translation']

        t_mat = quat2mat(t_quats)
        t_mat = R @ t_mat
        t_pts = t_pts @ R.T
        t_pts = t_pts * scale
        t_pts += global_translation
        t_quats = mat2quat(t_mat)
        t_scales = t_scales * scale
        t_pts = t_pts @ R_viewer.T
        t_quats = mat2quat(R_viewer @ quat2mat(t_quats))
        t_pts += t_viewer

        g_mat = quat2mat(g_quats)
        g_mat = R @ g_mat
        g_pts = g_pts @ R.T
        g_pts = g_pts * scale
        g_pts += global_translation
        g_quats = mat2quat(g_mat)
        g_scales = g_scales * scale
        g_pts = g_pts @ R_viewer.T
        g_quats = mat2quat(R_viewer @ quat2mat(g_quats))
        g_pts += t_viewer

        # TODO: center gripper in the viewer frame
        g_pts_tip = g_pts[g_pts_tip_mask]
        g_pts_tip_mean_xy = g_pts_tip[:, :2].mean(dim=0)
        g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.23]).to(torch.float32).to(self.device)
        g_pts = g_pts + g_pts_translation

        self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
        self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities
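
    # update_rendervar below advances the Gaussians by one predicted particle
    # step: the simulation particles act as "bones", each Gaussian is attached
    # to its k_wgt nearest bones with inverse-distance weights, and
    # interpolate_motions blends the per-bone displacements (x_pred - x) into a
    # per-Gaussian motion (new means + rotated quaternions), in the spirit of
    # linear blend skinning.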

    def update_rendervar(self, rendervar):
        p_x = self.state['x']
        p_x_viewer = self.inverse_preprocess_x(p_x)
        p_x_pred = self.state['x_pred']
        p_x_pred_viewer = self.inverse_preprocess_x(p_x_pred)

        xyz = rendervar['means3D']
        rgb = rendervar['colors_precomp']
        quat = rendervar['rotations']
        opa = rendervar['opacities']
        scales = rendervar['scales']

        relations = self.knn_relations(p_x_viewer)
        weights = self.knn_weights_brute(p_x_viewer, xyz)
        xyz, quat, _ = interpolate_motions(
            bones=p_x_viewer,
            motions=p_x_pred_viewer - p_x_viewer,
            relations=relations,
            weights=weights,
            xyz=xyz,
            quat=quat,
        )

        # normalize
        quat = torch.nn.functional.normalize(quat, dim=-1)

        rendervar = {
            'means3D': xyz,
            'colors_precomp': rgb,
            'rotations': quat,
            'opacities': opa,
            'scales': scales,
            'means2D': torch.zeros_like(xyz),
        }

        if self.with_bg:
            t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params

            # merge
            xyz = torch.cat([xyz, t_pts], dim=0)
            rgb = torch.cat([rgb, t_colors], dim=0)
            quat = torch.cat([quat, t_quats], dim=0)
            opa = torch.cat([opa, t_opacities], dim=0)
            scales = torch.cat([scales, t_scales], dim=0)

            if self.render_gripper:
                g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params

                # add gripper pos
                g_pts = g_pts + self.inverse_preprocess_gripper(self.state['prev_key_pos'][None].clone())[0]

                # merge
                xyz = torch.cat([xyz, g_pts], dim=0)
                rgb = torch.cat([rgb, g_colors], dim=0)
                quat = torch.cat([quat, g_quats], dim=0)
                opa = torch.cat([opa, g_opacities], dim=0)
                scales = torch.cat([scales, g_scales], dim=0)

            # normalize
            quat = torch.nn.functional.normalize(quat, dim=-1)
            rendervar_full = {
                'means3D': xyz,
                'colors_precomp': rgb,
                'rotations': quat,
                'opacities': opa,
                'scales': scales,
                'means2D': torch.zeros_like(xyz),
            }
        else:
            rendervar_full = rendervar

        return rendervar, rendervar_full
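
    # reset_state builds a default look-at camera: the pose is assembled from
    # spherical coordinates (distance / elevation / azimuth around `center`) as
    # a camera-to-world matrix and then inverted into w2c. The intrinsics use a
    # focal length of w/2 pixels, i.e. a 90-degree horizontal field of view.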

    def reset_state(self, params, visualize_image=False, init=False):
        xyz_0 = params['means3D']
        rgb_0 = params['rgb_colors']
        quat_0 = torch.nn.functional.normalize(params['unnorm_rotations'])
        opa_0 = torch.sigmoid(params['logit_opacities'])
        scales_0 = torch.exp(params['log_scales'])

        rendervar_init = {  # before preprocess
            'means3D': xyz_0,
            'colors_precomp': rgb_0,
            'rotations': quat_0,
            'opacities': opa_0,
            'scales': scales_0,
            'means2D': torch.zeros_like(xyz_0),
        }

        w = self.width
        h = self.height

        center = (0, 0, 0.1)
        distance = 0.7
        elevation = 20
        azimuth = 180.0
        target = np.array(center)
        theta = 90 + azimuth
        z = distance * math.sin(math.radians(elevation))
        y = math.cos(math.radians(theta)) * distance * math.cos(math.radians(elevation))
        x = math.sin(math.radians(theta)) * distance * math.cos(math.radians(elevation))
        origin = target + np.array([x, y, z])
        look_at = target - origin
        look_at /= np.linalg.norm(look_at)
        up = np.array([0.0, 0.0, 1.0])
        right = np.cross(look_at, up)
        right /= np.linalg.norm(right)
        up = np.cross(right, look_at)
        w2c = np.eye(4)  # built as camera-to-world, inverted below
        w2c[:3, 0] = right
        w2c[:3, 1] = -up
        w2c[:3, 2] = look_at
        w2c[:3, 3] = origin
        w2c = np.linalg.inv(w2c)
        k = np.array([
            [w / 2 * 1.0, 0., w / 2],
            [0., w / 2 * 1.0, h / 2],
            [0., 0., 1.],
        ])
        self.update_camera(k, w2c, w, h)

        n_particles = self.cfg.sim.n_particles
        downsample_indices = fps(xyz_0, torch.ones_like(xyz_0[:, 0]).to(torch.bool), n_particles, self.torch_device)
        p_x_viewer = xyz_0[downsample_indices]
        p_x = self.preprocess_x(p_x_viewer)

        self.state['x'] = p_x
        self.state['v'] = torch.zeros_like(p_x)
        self.state['x_his'] = p_x[None].repeat(self.cfg.sim.n_history, 1, 1)
        self.state['v_his'] = torch.zeros_like(p_x[None].repeat(self.cfg.sim.n_history, 1, 1))
        self.state['x_pred'] = p_x
        self.state['v_pred'] = torch.zeros_like(p_x)

        rendervar_init, rendervar_init_full = self.update_rendervar(rendervar_init)
        im, depth = self.render(rendervar_init_full, 0, bg=[0.0, 0.0, 0.0])
        im_vis = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8)
        return rendervar_init

    def _init_devices(self):
        # Each @spaces.GPU call may land on a fresh GPU worker, so Warp and the
        # torch device handles are (re-)initialized at the start of every handler.
        wp.init()
        gpus = [int(gpu) for gpu in self.cfg.gpus]
        wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
        torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
        device_count = len(torch_devices)
        assert device_count == 1
        self.wp_device = wp_devices[0]
        self.torch_device = torch_devices[0]

    @spaces.GPU
    def reset(self):
        self._init_devices()

        in_dir = root / 'log/gs/ckpts/rope_scene_1'
        batch_size = 1
        num_steps = 1
        num_particles = self.cfg.sim.n_particles

        self.load_scaniverse(in_dir)
        self.init_model(batch_size, num_steps, num_particles, ckpt_path=None)

        params = self.preprocess_gs(self.params)
        if self.with_bg:
            self.preprocess_bg_gs()

        rendervar = self.reset_state(params, visualize_image=False, init=True)
        rendervar, rendervar_full = self.update_rendervar(rendervar)
        self.rendervar = rendervar

        im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
        im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()
        cv2.imwrite(str(root / 'log/temp_init/0000.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR))
        make_video(root / 'log/temp_init', root / 'log/gs/temp/form_video_init.mp4', '%04d.png', 1)

        gs_pred = save_to_splat(
            rendervar_full['means3D'].cpu().numpy(),
            rendervar_full['colors_precomp'].cpu().numpy(),
            rendervar_full['scales'].cpu().numpy(),
            rendervar_full['rotations'].cpu().numpy(),
            rendervar_full['opacities'].cpu().numpy(),
            root / 'log/gs/temp/gs_pred.splat',
            rot_rev=True,
        )
        form_video = gr.Video(
            label='Predicted video',
            value=root / 'log/gs/temp/form_video.mp4',
            format='mp4',
            width=self.width,
            height=self.height,
        )
        form_3dgs_pred = gr.Model3D(
            label='Predicted Gaussian Splats',
            height=self.height,
            value=root / 'log/gs/temp/gs_pred.splat',
            clear_color=[0, 0, 0, 0],
        )
        return form_video, form_3dgs_pred
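
    # run_command integrates a constant end-effector velocity for 15 sub-steps
    # of dt = 0.1 s. The command is given in cm/s in the viewer frame, rotated
    # into the model frame, and pushed into state['sub_pos'] so that each call
    # to step() consumes exactly one key-position update; positions advance by
    # command * dt * 0.01 (cm -> m) per sub-step.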

    def run_command(self, unit_command):
        os.system('rm -rf ' + str(root / 'log/temp/*'))
        # im_list = []
        for i in range(15):
            dt = 0.1  # 100 ms per sub-step
            command = torch.tensor([unit_command]).to(self.device).to(torch.float32)  # cm/s (the jog buttons send +/-5)
            command = self.preprocess_gripper(command)
            # command_timestamp = torch.tensor([self.state['prev_key_pos_timestamp'] + (i + 1) * dt]).to(self.device).to(torch.float32)
            if self.verbose:
                print('command:', command.cpu().numpy().tolist())

            assert self.state['sub_pos'] is None
            if self.state['sub_pos'] is None:
                eef_xyz_latest = self.state['prev_key_pos']
                # eef_xyz_timestamp_latest = self.state['prev_key_pos_timestamp']
            else:
                eef_xyz_latest = self.state['sub_pos'][-1]  # (1, 3), model frame
                # eef_xyz_timestamp_latest = self.state['sub_pos_timestamps'][-1].item()

            eef_xyz_updated = eef_xyz_latest + command * dt * 0.01  # cm to m
            if self.state['sub_pos'] is None:
                self.state['sub_pos'] = eef_xyz_updated[None]
                # self.state['sub_pos_timestamps'] = command_timestamp
            else:
                self.state['sub_pos'] = torch.cat([self.state['sub_pos'], eef_xyz_updated[None]], dim=0)
                # self.state['sub_pos_timestamps'] = torch.cat([self.state['sub_pos_timestamps'], command_timestamp], dim=0)

            self.step()
            rendervar, rendervar_full = self.update_rendervar(self.rendervar)
            self.rendervar = rendervar
            im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
            im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()
            # im_list.append(im_show)
            cv2.imwrite(str(root / f'log/temp/{i:04}.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR))

        # zero out velocities after the command finishes so the object settles
        # self.state['prev_key_pos_timestamp'] = self.state['prev_key_pos_timestamp'] + 20 * dt
        self.state['v'] *= 0.0
        self.state['x'] = self.state['x_pred'].clone()
        self.state['x_his'] = self.state['x'][None].repeat(self.cfg.sim.n_history, 1, 1)
        self.state['v_his'] *= 0.0
        self.state['v_pred'] *= 0.0

        make_video(root / 'log/temp', root / 'log/gs/temp/form_video.mp4', '%04d.png', 5)
        form_video = gr.Video(
            label='Predicted video',
            value=root / 'log/gs/temp/form_video.mp4',
            format='mp4',
            width=self.width,
            height=self.height,
        )

        im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
        im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()
        gs_pred = save_to_splat(
            rendervar_full['means3D'].cpu().numpy(),
            rendervar_full['colors_precomp'].cpu().numpy(),
            rendervar_full['scales'].cpu().numpy(),
            rendervar_full['rotations'].cpu().numpy(),
            rendervar_full['opacities'].cpu().numpy(),
            root / 'log/gs/temp/gs_pred.splat',
            rot_rev=True,
        )
        form_3dgs_pred = gr.Model3D(
            label='Predicted Gaussian Splats',
            height=self.height,
            value=root / 'log/gs/temp/gs_pred.splat',
            clear_color=[0, 0, 0, 0],
        )
        return form_video, form_3dgs_pred

    # The six direction handlers are identical except for the commanded velocity
    # (cm/s); each re-initializes devices because of @spaces.GPU scheduling.

    @spaces.GPU
    def on_click_run_xplus(self):
        self._init_devices()
        return self.run_command([5.0, 0, 0])

    @spaces.GPU
    def on_click_run_xminus(self):
        self._init_devices()
        return self.run_command([-5.0, 0, 0])

    @spaces.GPU
    def on_click_run_yplus(self):
        self._init_devices()
        return self.run_command([0, 5.0, 0])

    @spaces.GPU
    def on_click_run_yminus(self):
        self._init_devices()
        return self.run_command([0, -5.0, 0])

    @spaces.GPU
    def on_click_run_zplus(self):
        self._init_devices()
        return self.run_command([0, 0, 5.0])

    @spaces.GPU
    def on_click_run_zminus(self):
        self._init_devices()
        return self.run_command([0, 0, -5.0])
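
    # launch() assembles the Gradio UI: a video pane and a 3D splat viewer on
    # top, a Reset button plus six jog buttons (x/y/z, +/-) below, each wired
    # to the handlers above; every click re-renders the rollout and updates
    # both output widgets.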

    def launch(self, share=False):
        with gr.Blocks() as app:
            with gr.Row():
                gr.Markdown("# Particle-Grid Neural Dynamics for Learning Deformable Object Models from RGB-D Videos")
            with gr.Row():
                gr.Markdown('### Project page: [https://kywind.github.io/pgnd](https://kywind.github.io/pgnd)')
            with gr.Row():
                # with gr.Column(scale=2):
                #     form_3dgs_orig = gr.Model3D(
                #         label='Original Gaussian Splats',
                #         value=None,
                #     )
                with gr.Column(scale=2):
                    form_video = gr.Video(
                        label='Predicted video',
                        value=None,
                        format='mp4',
                        width=self.width,
                        height=self.height,
                    )
                with gr.Column(scale=2):
                    form_3dgs_pred = gr.Model3D(
                        label='Predicted Gaussians',
                        height=self.height,
                        value=None,
                        clear_color=[0, 0, 0, 0],
                    )

            # Layout
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Row():
                        run_reset = gr.Button("Reset")
                    with gr.Row():
                        with gr.Column():
                            run_xminus = gr.Button("x-")
                        with gr.Column():
                            run_xplus = gr.Button("x+")
                    with gr.Row():
                        with gr.Column():
                            run_yminus = gr.Button("y-")
                        with gr.Column():
                            run_yplus = gr.Button("y+")
                    with gr.Row():
                        with gr.Column():
                            run_zminus = gr.Button("z-")
                        with gr.Column():
                            run_zplus = gr.Button("z+")
                with gr.Column(scale=2):
                    _ = gr.Button(visible=False)  # empty placeholder

            # Set up callbacks
            run_reset.click(self.reset, inputs=[], outputs=[form_video, form_3dgs_pred])
            run_xplus.click(self.on_click_run_xplus, inputs=[], outputs=[form_video, form_3dgs_pred])
            run_xminus.click(self.on_click_run_xminus, inputs=[], outputs=[form_video, form_3dgs_pred])
            run_yplus.click(self.on_click_run_yplus, inputs=[], outputs=[form_video, form_3dgs_pred])
            run_yminus.click(self.on_click_run_yminus, inputs=[], outputs=[form_video, form_3dgs_pred])
            run_zplus.click(self.on_click_run_zplus, inputs=[], outputs=[form_video, form_3dgs_pred])
            run_zminus.click(self.on_click_run_zminus, inputs=[], outputs=[form_video, form_3dgs_pred])

        app.launch(share=share)


if __name__ == '__main__':
    visualizer = DynamicsVisualizer()
    visualizer.launch(share=True)