import gradio as gr
import importlib
import sys
import site
from PIL import Image
from pathlib import Path
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm, trange
import random
import math
import hydra
import numpy as np
import torch
import torch.nn as nn
print(torch.__version__)
print(torch.version.cuda)
# import torch.backends.cudnn
import warp as wp
import glob
from torch.utils.data import DataLoader
import os
import subprocess
import time
import cv2
import copy
import kornia
import yaml
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
import spaces
from spaces import zero
zero.startup()
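# The rasterizer extension installed below is compiled from source at startup,
# which needs nvcc; the container is assumed not to ship a full CUDA toolkit,
# so one is downloaded and installed silently first.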
def install_cuda_toolkit():
# CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
# CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run"
CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)
subprocess.call(["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE])
subprocess.call(["chmod", "+x", CUDA_TOOLKIT_FILE])
subprocess.call([CUDA_TOOLKIT_FILE, "--silent", "--toolkit"])
os.environ["CUDA_HOME"] = "/usr/local/cuda"
os.environ["PATH"] = "%s/bin:%s" % (os.environ["CUDA_HOME"], os.environ["PATH"])
os.environ["LD_LIBRARY_PATH"] = "%s/lib:%s" % (
os.environ["CUDA_HOME"],
"" if "LD_LIBRARY_PATH" not in os.environ else os.environ["LD_LIBRARY_PATH"],
)
# Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
install_cuda_toolkit()
gs_path = Path(__file__).parent / "src/third-party/diff-gaussian-rasterization-w-depth"
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", str(gs_path)])
site.main() # re-processes every *.pth in site-packages
importlib.invalidate_caches()
diff_gaussian_rasterization = importlib.import_module("diff_gaussian_rasterization")
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "dgl",
    "-f", "https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html",
])
# os.system('conda install conda-forge::ffmpeg')
sys.path.insert(0, str(Path(__file__).parent / "src"))
sys.path.append(str(Path(__file__).parent / "src" / "experiments"))
from pgnd.sim import Friction, CacheDiffSimWithFrictionBatch, StaticsBatch, CollidersBatch
from pgnd.material import PGNDModel
from pgnd.utils import Logger, get_root, mkdir
from pgnd.ffmpeg import make_video
from real_world.utils.render_utils import interpolate_motions
from real_world.gs.helpers import setup_camera
from real_world.gs.convert import save_to_splat, read_splat
from diff_gaussian_rasterization import GaussianRasterizer
from diff_gaussian_rasterization import GaussianRasterizationSettings as Camera
root = Path(__file__).parent / "src" / "experiments"
def quat2mat(quat):
return kornia.geometry.conversions.quaternion_to_rotation_matrix(quat)
def mat2quat(mat):
return kornia.geometry.conversions.rotation_matrix_to_quaternion(mat)
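def Rt_to_w2c(R, t):
    # minimal helper for set_camera below; assumes R (3, 3) and t (3,) are
    # already the world-to-camera rotation and translation and packs them
    # into a homogeneous 4x4 matrix
    w2c = np.eye(4)
    w2c[:3, :3] = R
    w2c[:3, 3] = t
    return w2c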
def fps(x, enabled, n, device, random_start=False):
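    # farthest point sampling over the enabled points; returns n indices into
    # x[enabled], optionally starting from a random seed point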
from dgl.geometry import farthest_point_sampler
assert torch.diff(enabled * 1.0).sum() in [0.0, -1.0]
start_idx = random.randint(0, enabled.sum() - 1) if random_start else 0
fps_idx = farthest_point_sampler(x[enabled][None], n, start_idx=start_idx)[0]
fps_idx = fps_idx.to(x.device)
return fps_idx
class DynamicsVisualizer:
def __init__(self):
self.width = 640
self.height = 480
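        # per-task checkpoint registry: [log subdir, run name, checkpoint iteration, [start_episode, end_episode]]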
best_models = {
'cloth': ['cloth', 'train', 100000, [610, 650]],
'rope': ['rope', 'train', 100000, [651, 691]],
'paperbag': ['paperbag', 'train', 100000, [200, 220]],
'sloth': ['sloth', 'train', 100000, [113, 133]],
'box': ['box', 'train', 100000, [306, 323]],
'bread': ['bread', 'train', 100000, [143, 163]],
}
task_name = 'rope'
with open(root / f'log/{best_models[task_name][0]}/{best_models[task_name][1]}/hydra.yaml', 'r') as f:
config = yaml.load(f, Loader=yaml.CLoader)
cfg = OmegaConf.create(config)
cfg.iteration = best_models[task_name][2]
cfg.start_episode = best_models[task_name][3][0]
cfg.end_episode = best_models[task_name][3][1]
cfg.sim.num_steps = 1000
cfg.sim.gripper_forcing = False
cfg.sim.uniform = True
cfg.sim.use_pv = False
device = torch.device('cuda')
self.cfg = cfg
self.device = device
self.k_rel = 8 # knn for relations
self.k_wgt = 16 # knn for weights
self.with_bg = True
self.render_gripper = True
self.verbose = False
self.dt_base = cfg.sim.dt
self.high_freq_pred = True
seed = cfg.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# torch.autograd.set_detect_anomaly(True)
# torch.backends.cudnn.benchmark = True
self.clear()
def clear(self, clear_params=True):
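        # reset camera metadata, render config, and simulation state;
        # keep the loaded gaussian params unless clear_params is set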
self.metadata = {}
self.config = {}
if clear_params:
self.params = None
self.state = {
# object
'x': None,
'v': None,
'x_his': None,
'v_his': None,
'x_pred': None,
'v_pred': None,
'clip_bound': None,
'enabled': None,
# robot
'prev_key_pos': None,
'prev_key_pos_timestamp': None,
'sub_pos': None, # filling in between key positions
'sub_pos_timestamps': None,
'gripper_radius': None,
}
self.preprocess_metadata = None
self.table_params = None
self.gripper_params = None
self.sim = None
self.statics = None
self.colliders = None
self.material = None
self.friction = None
def load_scaniverse(self, data_path):
### load splat params
params_obj = read_splat(data_path / 'object.splat')
params_table = read_splat(data_path / 'table.splat')
params_robot = read_splat(data_path / 'gripper.splat')
pts, colors, scales, quats, opacities = params_obj
self.params = {
'means3D': torch.from_numpy(pts).to(torch.float32).to(self.device),
'rgb_colors': torch.from_numpy(colors).to(torch.float32).to(self.device),
'log_scales': torch.log(torch.from_numpy(scales).to(torch.float32).to(self.device)),
'unnorm_rotations': torch.from_numpy(quats).to(torch.float32).to(self.device),
'logit_opacities': torch.logit(torch.from_numpy(opacities).to(torch.float32).to(self.device))
}
t_pts, t_colors, t_scales, t_quats, t_opacities = params_table
t_pts = torch.tensor(t_pts).to(torch.float32).to(self.device)
t_colors = torch.tensor(t_colors).to(torch.float32).to(self.device)
t_scales = torch.tensor(t_scales).to(torch.float32).to(self.device)
t_quats = torch.tensor(t_quats).to(torch.float32).to(self.device)
t_opacities = torch.tensor(t_opacities).to(torch.float32).to(self.device)
g_pts, g_colors, g_scales, g_quats, g_opacities = params_robot
g_pts = torch.tensor(g_pts).to(torch.float32).to(self.device)
g_colors = torch.tensor(g_colors).to(torch.float32).to(self.device)
g_scales = torch.tensor(g_scales).to(torch.float32).to(self.device)
g_quats = torch.tensor(g_quats).to(torch.float32).to(self.device)
g_opacities = torch.tensor(g_opacities).to(torch.float32).to(self.device)
self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities # data frame
self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities # data frame
n_particles = self.cfg.sim.n_particles
self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)
### load preprocess metadata
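        # data frame -> model frame: apply the fixed rotation R, then translate
        # so the object is centered at 0.5 in x/z with its lowest point resting
        # just above the grid clip bound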
cfg = self.cfg
dx = cfg.sim.num_grids[-1]
p_x = torch.tensor(pts).to(torch.float32).to(self.device)
R = torch.tensor(
[[1, 0, 0],
[0, 0, -1],
[0, 1, 0]]
).to(p_x.device).to(p_x.dtype)
p_x_rotated = p_x @ R.T
scale = 1.0
p_x_rotated_scaled = p_x_rotated * scale
global_translation = torch.tensor([
0.5 - p_x_rotated_scaled[:, 0].mean(),
dx * (cfg.model.clip_bound + 0.5) - p_x_rotated_scaled[:, 1].min(),
0.5 - p_x_rotated_scaled[:, 2].mean(),
], dtype=p_x_rotated_scaled.dtype, device=p_x_rotated_scaled.device)
R_viewer = torch.tensor(
[[1, 0, 0],
[0, 0, -1],
[0, 1, 0]]
).to(p_x.device).to(p_x.dtype)
t_viewer = torch.tensor([0, 0, 0]).to(p_x.device).to(p_x.dtype)
self.preprocess_metadata = {
'R': R,
'R_viewer': R_viewer,
't_viewer': t_viewer,
'scale': scale,
'global_translation': global_translation,
}
### load eef
grippers = np.loadtxt(data_path / 'eef_xyz.txt')[None]
assert grippers.shape == (1, 3)
if grippers is not None:
grippers = torch.tensor(grippers).to(self.device).to(torch.float32)
# transform
# data frame to model frame
R = self.preprocess_metadata['R']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
grippers[:, :3] = grippers[:, :3] @ R.T
grippers[:, :3] = grippers[:, :3] * scale
grippers[:, :3] += global_translation
assert grippers.shape[0] == 1
self.state['prev_key_pos'] = grippers[:, :3] # (1, 3)
# self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32)
self.state['gripper_radius'] = cfg.model.gripper_radius
def load_params(self, params_path, remove_low_opa=True, remove_black=False):
pts, colors, scales, quats, opacities = read_splat(params_path)
if remove_low_opa:
low_opa_idx = opacities[:, 0] < 0.1
pts = pts[~low_opa_idx]
colors = colors[~low_opa_idx]
quats = quats[~low_opa_idx]
opacities = opacities[~low_opa_idx]
scales = scales[~low_opa_idx]
if remove_black:
low_color_idx = colors.sum(axis=-1) < 0.5
pts = pts[~low_color_idx]
colors = colors[~low_color_idx]
quats = quats[~low_color_idx]
opacities = opacities[~low_color_idx]
scales = scales[~low_color_idx]
self.params = {
'means3D': torch.from_numpy(pts).to(torch.float32).to(self.device),
'rgb_colors': torch.from_numpy(colors).to(torch.float32).to(self.device),
'log_scales': torch.log(torch.from_numpy(scales).to(torch.float32).to(self.device)),
'unnorm_rotations': torch.from_numpy(quats).to(torch.float32).to(self.device),
'logit_opacities': torch.logit(torch.from_numpy(opacities).to(torch.float32).to(self.device))
}
table_splat = root / 'log/gs/ckpts/table.splat'
sphere_splat = root / 'log/gs/ckpts/sphere.splat'
gripper_splat = root / 'log/gs/ckpts/gripper.splat' # gripper_new.splat
table_params = read_splat(table_splat) # numpy
## add table and gripper
# add table
t_pts, t_colors, t_scales, t_quats, t_opacities = table_params
t_pts = torch.tensor(t_pts).to(torch.float32).to(self.device)
t_colors = torch.tensor(t_colors).to(torch.float32).to(self.device)
t_scales = torch.tensor(t_scales).to(torch.float32).to(self.device)
t_quats = torch.tensor(t_quats).to(torch.float32).to(self.device)
t_opacities = torch.tensor(t_opacities).to(torch.float32).to(self.device)
# add table pos
t_pts = t_pts + torch.tensor([0, 0, 0.02]).to(torch.float32).to(self.device)
# add gripper
gripper_params = read_splat(gripper_splat) # numpy
g_pts, g_colors, g_scales, g_quats, g_opacities = gripper_params
g_pts = torch.tensor(g_pts).to(torch.float32).to(self.device)
g_colors = torch.tensor(g_colors).to(torch.float32).to(self.device)
g_scales = torch.tensor(g_scales).to(torch.float32).to(self.device)
g_quats = torch.tensor(g_quats).to(torch.float32).to(self.device)
g_opacities = torch.tensor(g_opacities).to(torch.float32).to(self.device)
# we do not do the gripper translation now because this will center the gripper in the data frame but not the viewer frame
self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities # data frame
self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities # data frame
# load other info
n_particles = self.cfg.sim.n_particles
self.state['clip_bound'] = torch.tensor([self.cfg.model.clip_bound], dtype=torch.float32)
self.state['enabled'] = torch.ones(n_particles, dtype=torch.bool)
def set_camera(self, w, h, intr, w2c=None, R=None, t=None, near=0.01, far=100.0):
if w2c is None:
assert R is not None and t is not None
w2c = Rt_to_w2c(R, t)
self.metadata = {
'w': w,
'h': h,
'k': intr,
'w2c': w2c,
}
self.config = {'near': near, 'far': far}
def load_eef(self, grippers=None, eef_t=None):
assert self.state['prev_key_pos'] is None
if grippers is not None:
grippers = torch.tensor(grippers).to(self.device).to(torch.float32)
eef_t = torch.tensor(eef_t).to(self.device).to(torch.float32)
grippers[:, :3] = grippers[:, :3] + eef_t
# transform
# data frame to model frame
R = self.preprocess_metadata['R']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
grippers[:, :3] = grippers[:, :3] @ R.T
grippers[:, :3] = grippers[:, :3] * scale
grippers[:, :3] += global_translation
assert grippers.shape[0] == 1
self.state['prev_key_pos'] = grippers[:, :3] # (1, 3)
# self.state['prev_key_pos_timestamp'] = torch.zeros(1).to(self.device).to(torch.float32) + 0.001
self.state['gripper_radius'] = self.cfg.model.gripper_radius
def load_preprocess_metadata(self, p_x_orig):
cfg = self.cfg
dx = cfg.sim.num_grids[-1]
p_x_orig = p_x_orig.to(self.device)
R = torch.tensor(
[[1, 0, 0],
[0, 0, -1],
[0, 1, 0]]
).to(p_x_orig.device).to(p_x_orig.dtype)
p_x_orig_rotated = torch.einsum('nij,jk->nik', p_x_orig, R.T)
scale = 1.0
p_x_orig_rotated_scaled = p_x_orig_rotated * scale
global_translation = torch.tensor([
0.5 - p_x_orig_rotated_scaled[:, :, 0].mean(),
dx * (cfg.model.clip_bound + 0.5) - p_x_orig_rotated_scaled[:, :, 1].min(),
0.5 - p_x_orig_rotated_scaled[:, :, 2].mean(),
], dtype=p_x_orig_rotated_scaled.dtype, device=p_x_orig_rotated_scaled.device)
R_viewer = torch.tensor(
[[1, 0, 0],
[0, 0, -1],
[0, 1, 0]]
).to(p_x_orig.device).to(p_x_orig.dtype)
t_viewer = torch.tensor([0, 0, 0]).to(p_x_orig.device).to(p_x_orig.dtype)
self.preprocess_metadata = {
'R': R,
'R_viewer': R_viewer,
't_viewer': t_viewer,
'scale': scale,
'global_translation': global_translation,
}
@torch.no_grad
def render(self, render_data, cam_id, bg=[0.7, 0.7, 0.7]):
render_data = {k: v.to(self.device) for k, v in render_data.items()}
w, h = self.metadata['w'], self.metadata['h']
k, w2c = self.metadata['k'], self.metadata['w2c']
cam = setup_camera(w, h, k, w2c, self.config['near'], self.config['far'], bg)
        im, _, depth = GaussianRasterizer(raster_settings=cam)(**render_data)
return im, depth
def knn_relations(self, bones):
k = self.k_rel
knn = NearestNeighbors(n_neighbors=k+1, algorithm='kd_tree').fit(bones.detach().cpu().numpy())
_, indices = knn.kneighbors(bones.detach().cpu().numpy()) # (N, k)
indices = indices[:, 1:] # exclude self
return indices
def knn_weights_brute(self, bones, pts):
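        # dense (n_pts, n_bones) skinning weights: inverse-distance weights over
        # each point's k nearest bones, normalized to sum to 1; entries outside
        # the k selected bones stay zero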
k = self.k_wgt
dist = torch.norm(pts[:, None] - bones, dim=-1) # (n_pts, n_bones)
_, indices = torch.topk(dist, k, dim=-1, largest=False)
bones_selected = bones[indices] # (N, k, 3)
dist = torch.norm(bones_selected - pts[:, None], dim=-1) # (N, k)
weights = 1 / (dist + 1e-6)
weights = weights / weights.sum(dim=-1, keepdim=True) # (N, k)
weights_all = torch.zeros((pts.shape[0], bones.shape[0]), device=pts.device)
weights_all[torch.arange(pts.shape[0])[:, None], indices] = weights
return weights_all
def update_camera(self, k, w2c, w=None, h=None, near=0.01, far=100.0):
self.metadata['k'] = k
self.metadata['w2c'] = w2c
if w is not None:
self.metadata['w'] = w
if h is not None:
self.metadata['h'] = h
self.config['near'] = near
self.config['far'] = far
def init_model(self, batch_size, num_steps, num_particles, ckpt_path=None):
self.cfg.sim.num_steps = num_steps
cfg = self.cfg
sim = CacheDiffSimWithFrictionBatch(cfg, num_steps, batch_size, self.wp_device, requires_grad=True)
statics = StaticsBatch()
statics.init(shape=(batch_size, num_particles), device=self.wp_device)
statics.update_clip_bound(self.state['clip_bound'])
statics.update_enabled(self.state['enabled'][None])
colliders = CollidersBatch()
colliders.init(shape=(batch_size, cfg.sim.num_grippers), device=self.wp_device)
self.sim = sim
self.statics = statics
self.colliders = colliders
        # load ckpt (fall back to the pretrained rope checkpoint when none is given)
        if ckpt_path is None:
            ckpt_path = root / 'log/rope/train/ckpt/100000.pt'
ckpt = torch.load(ckpt_path, map_location=self.torch_device)
material: nn.Module = PGNDModel(cfg)
material.to(self.torch_device)
material.load_state_dict(ckpt['material'])
material.requires_grad_(False)
material.eval()
if 'friction' in ckpt:
friction = ckpt['friction']['mu'].reshape(-1, 1)
else:
friction = torch.tensor(cfg.model.friction.value, device=self.torch_device).reshape(-1, 1)
self.material = material
self.friction = friction
def reload_model(self, num_steps): # only change num_steps
self.cfg.sim.num_steps = num_steps
sim = CacheDiffSimWithFrictionBatch(self.cfg, num_steps, 1, self.wp_device, requires_grad=True)
self.sim = sim
@torch.no_grad
def step(self):
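        # consume the queued sub positions as one key-to-key gripper motion and
        # advance the particle state by a single step of the learned simulator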
cfg = self.cfg
batch_size = 1
num_steps = 1
num_particles = cfg.sim.n_particles
# update state by previous prediction
self.state['x_his'] = torch.cat([self.state['x_his'][1:], self.state['x'][None]], dim=0)
self.state['v_his'] = torch.cat([self.state['v_his'][1:], self.state['v'][None]], dim=0)
self.state['x'] = self.state['x_pred'].clone()
self.state['v'] = self.state['v_pred'].clone()
eef_xyz_key = self.state['prev_key_pos'] # (1, 3), model frame
eef_xyz_sub = self.state['sub_pos'] # (T, 1, 3), model frame
if eef_xyz_sub is None:
return
# eef_xyz_key_timestamp = self.state['prev_key_pos_timestamp']
# eef_xyz_sub_timestamps = self.state['sub_pos_timestamps']
# assert eef_xyz_key_timestamp.item() > 0
# delta_t = (eef_xyz_sub_timestamps[-1] - eef_xyz_key_timestamp).item()
# if (not self.high_freq_pred) and delta_t < self.dt_base * 0.9:
# return
# cfg.sim.dt = delta_t
eef_xyz_key_next = eef_xyz_sub[-1] # (1, 3), model frame
eef_v = (eef_xyz_key_next - eef_xyz_key) / cfg.sim.dt
if self.verbose:
print('delta_t:', np.round(cfg.sim.dt, 4))
print('eef_xyz_key_next:', eef_xyz_key_next.cpu().numpy().tolist())
print('eef_xyz_key:', eef_xyz_key.cpu().numpy().tolist())
print('v:', eef_v.cpu().numpy().tolist())
# load model, sim, statics, colliders
self.reload_model(num_steps)
# initialize colliders
if cfg.sim.num_grippers > 0:
grippers = torch.zeros((batch_size, cfg.sim.num_grippers, 15), device=self.torch_device)
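            # per-gripper state layout (15 dims): xyz (0:3), linear velocity (3:6),
            # quaternion (6:10), angular velocity (10:13), radius (13), open/close (14)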
eef_quat = torch.tensor([1, 0, 0, 0], dtype=torch.float32, device=self.torch_device).repeat(batch_size, cfg.sim.num_grippers, 1) # (B, G, 4)
eef_quat_vel = torch.zeros((batch_size, cfg.sim.num_grippers, 3), dtype=torch.float32, device=self.torch_device)
eef_gripper = torch.zeros((batch_size, cfg.sim.num_grippers), dtype=torch.float32, device=self.torch_device)
grippers[:, :, :3] = eef_xyz_key
grippers[:, :, 3:6] = eef_v
grippers[:, :, 6:10] = eef_quat
grippers[:, :, 10:13] = eef_quat_vel
grippers[:, :, 13] = cfg.model.gripper_radius
grippers[:, :, 14] = eef_gripper
self.colliders.initialize_grippers(grippers)
x = self.state['x'].clone()[None].repeat(batch_size, 1, 1)
v = self.state['v'].clone()[None].repeat(batch_size, 1, 1)
x_his = self.state['x_his'].permute(1, 0, 2).clone()
assert x_his.shape[0] == num_particles
x_his = x_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1)
v_his = self.state['v_his'].permute(1, 0, 2).clone()
assert v_his.shape[0] == num_particles
v_his = v_his.reshape(num_particles, -1)[None].repeat(batch_size, 1, 1)
enabled = self.state['enabled'].clone().to(self.torch_device)[None].repeat(batch_size, 1)
for t in range(num_steps):
x_in = x.clone()
pred = self.material(x, v, x_his, v_his, enabled)
# x_his = torch.cat([x_his.reshape(batch_size, num_particles, -1, 3)[:, :, 1:], x[:, :, None].detach()], dim=2)
# v_his = torch.cat([v_his.reshape(batch_size, num_particles, -1, 3)[:, :, 1:], v[:, :, None].detach()], dim=2)
# x_his = x_his.reshape(batch_size, num_particles, -1)
# v_his = v_his.reshape(batch_size, num_particles, -1)
x, v = self.sim(self.statics, self.colliders, t, x, v, self.friction, pred)
# calculate new x_pred, v_pred, eef_xyz_key and eef_xyz_sub
x_pred = x[0].clone()
v_pred = v[0].clone()
self.state['x_pred'] = x_pred
self.state['v_pred'] = v_pred
# self.state['x_his'] = x_his[0].reshape(num_particles, self.cfg.sim.n_history, 3).permute(1, 0, 2).clone()
# self.state['v_his'] = v_his[0].reshape(num_particles, self.cfg.sim.n_history, 3).permute(1, 0, 2).clone()
self.state['prev_key_pos'] = eef_xyz_key_next
# self.state['prev_key_pos_timestamp'] = eef_xyz_sub_timestamps[-1]
self.state['sub_pos'] = None
# self.state['sub_pos_timestamps'] = None
def preprocess_x(self, p_x): # viewer frame to model frame (not data frame)
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
t_viewer = self.preprocess_metadata['t_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
# viewer frame to model frame
p_x = (p_x - t_viewer) @ R_viewer
# model frame to data frame
# p_x -= global_translation
# p_x = p_x / scale
# p_x = p_x @ torch.linalg.inv(R).T
return p_x
def preprocess_gripper(self, grippers): # viewer frame to model frame (not data frame)
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
t_viewer = self.preprocess_metadata['t_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
# viewer frame to model frame
grippers[:, :3] = grippers[:, :3] @ R_viewer
return grippers
def inverse_preprocess_x(self, p_x): # model frame (not data frame) to viewer frame
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
t_viewer = self.preprocess_metadata['t_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
# model frame to viewer frame
p_x = p_x @ R_viewer.T + t_viewer
return p_x
def inverse_preprocess_gripper(self, grippers): # model frame (not data frame) to viewer frame
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
t_viewer = self.preprocess_metadata['t_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
# model frame to viewer frame
grippers[:, :3] = grippers[:, :3] @ R_viewer.T + t_viewer
return grippers
    def rotate(self, params, rot_mat):
        # rotate gaussian means and orientations by rot_mat; any uniform scale
        # carried in the row norms of rot_mat is split off so that quaternion
        # composition uses a pure rotation
        rot_mat = torch.as_tensor(rot_mat, dtype=torch.float32, device=self.device)
        scale = torch.linalg.norm(rot_mat, dim=1, keepdim=True)
        rot = rot_mat / scale
        pts = params['means3D'] @ rot_mat.T
        quats = mat2quat(rot[None] @ quat2mat(params['unnorm_rotations']))
        params = {
            'means3D': pts,
            'rgb_colors': params['rgb_colors'],
            'log_scales': params['log_scales'] + torch.log(scale).mean(),
            'unnorm_rotations': quats,
            'logit_opacities': params['logit_opacities'],
        }
        return params
def preprocess_gs(self, params):
if isinstance(params, dict):
xyz = params['means3D']
rgb = params['rgb_colors']
quat = torch.nn.functional.normalize(params['unnorm_rotations'])
opa = torch.sigmoid(params['logit_opacities'])
scales = torch.exp(params['log_scales'])
else:
assert isinstance(params, tuple)
xyz, rgb, quat, opa, scales = params
quat = torch.nn.functional.normalize(quat, dim=-1)
# transform
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
mat = quat2mat(quat)
mat = R @ mat
xyz = xyz @ R.T
xyz = xyz * scale
xyz += global_translation
quat = mat2quat(mat)
scales = scales * scale
# viewer-specific transform (flip y and z)
# model frame to viewer frame
xyz = xyz @ R_viewer.T
quat = mat2quat(R_viewer @ quat2mat(quat))
t_viewer = -xyz.mean(dim=0)
t_viewer[2] = 0
xyz += t_viewer
print('Overwriting t_viewer to be the planar mean of the object')
self.preprocess_metadata['t_viewer'] = t_viewer
if isinstance(params, dict):
params['means3D'] = xyz
params['rgb_colors'] = rgb
params['unnorm_rotations'] = quat
params['logit_opacities'] = opa
params['log_scales'] = torch.log(scales)
else:
params = xyz, rgb, quat, opa, scales
return params
def preprocess_bg_gs(self):
t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params
g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params
# identify tip first
g_pts_tip_z = g_pts[:, 2].max()
g_pts_tip_mask = (g_pts[:, 2] > g_pts_tip_z - 0.04) & (g_pts[:, 2] < g_pts_tip_z)
R = self.preprocess_metadata['R']
R_viewer = self.preprocess_metadata['R_viewer']
t_viewer = self.preprocess_metadata['t_viewer']
scale = self.preprocess_metadata['scale']
global_translation = self.preprocess_metadata['global_translation']
t_mat = quat2mat(t_quats)
t_mat = R @ t_mat
t_pts = t_pts @ R.T
t_pts = t_pts * scale
t_pts += global_translation
t_quats = mat2quat(t_mat)
t_scales = t_scales * scale
t_pts = t_pts @ R_viewer.T
t_quats = mat2quat(R_viewer @ quat2mat(t_quats))
t_pts += t_viewer
g_mat = quat2mat(g_quats)
g_mat = R @ g_mat
g_pts = g_pts @ R.T
g_pts = g_pts * scale
g_pts += global_translation
g_quats = mat2quat(g_mat)
g_scales = g_scales * scale
g_pts = g_pts @ R_viewer.T
g_quats = mat2quat(R_viewer @ quat2mat(g_quats))
g_pts += t_viewer
        # center the gripper tip in the viewer frame (xy at the origin, fixed z offset)
g_pts_tip = g_pts[g_pts_tip_mask]
g_pts_tip_mean_xy = g_pts_tip[:, :2].mean(dim=0)
g_pts_translation = torch.tensor([-g_pts_tip_mean_xy[0], -g_pts_tip_mean_xy[1], -0.23]).to(torch.float32).to(self.device)
g_pts = g_pts + g_pts_translation
self.table_params = t_pts, t_colors, t_scales, t_quats, t_opacities
self.gripper_params = g_pts, g_colors, g_scales, g_quats, g_opacities
def update_rendervar(self, rendervar):
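        # skin the gaussians to the particles: particles act as bones, and
        # interpolate_motions moves each gaussian mean/rotation by a knn
        # inverse-distance-weighted blend of the particle motions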
p_x = self.state['x']
p_x_viewer = self.inverse_preprocess_x(p_x)
p_x_pred = self.state['x_pred']
p_x_pred_viewer = self.inverse_preprocess_x(p_x_pred)
xyz = rendervar['means3D']
rgb = rendervar['colors_precomp']
quat = rendervar['rotations']
opa = rendervar['opacities']
scales = rendervar['scales']
relations = self.knn_relations(p_x_viewer)
weights = self.knn_weights_brute(p_x_viewer, xyz)
xyz, quat, _ = interpolate_motions(
bones=p_x_viewer,
motions=p_x_pred_viewer - p_x_viewer,
relations=relations,
weights=weights,
xyz=xyz,
quat=quat,
)
# normalize
quat = torch.nn.functional.normalize(quat, dim=-1)
rendervar = {
'means3D': xyz,
'colors_precomp': rgb,
'rotations': quat,
'opacities': opa,
'scales': scales,
'means2D': torch.zeros_like(xyz),
}
if self.with_bg:
t_pts, t_colors, t_scales, t_quats, t_opacities = self.table_params
# merge
xyz = torch.cat([xyz, t_pts], dim=0)
rgb = torch.cat([rgb, t_colors], dim=0)
quat = torch.cat([quat, t_quats], dim=0)
opa = torch.cat([opa, t_opacities], dim=0)
scales = torch.cat([scales, t_scales], dim=0)
if self.render_gripper:
g_pts, g_colors, g_scales, g_quats, g_opacities = self.gripper_params
# add gripper pos
g_pts = g_pts + self.inverse_preprocess_gripper(self.state['prev_key_pos'][None].clone())[0]
# merge
xyz = torch.cat([xyz, g_pts], dim=0)
rgb = torch.cat([rgb, g_colors], dim=0)
quat = torch.cat([quat, g_quats], dim=0)
opa = torch.cat([opa, g_opacities], dim=0)
scales = torch.cat([scales, g_scales], dim=0)
# normalize
quat = torch.nn.functional.normalize(quat, dim=-1)
rendervar_full = {
'means3D': xyz,
'colors_precomp': rgb,
'rotations': quat,
'opacities': opa,
'scales': scales,
'means2D': torch.zeros_like(xyz),
}
else:
rendervar_full = rendervar
return rendervar, rendervar_full
def reset_state(self, params, visualize_image=False, init=False):
xyz_0 = params['means3D']
rgb_0 = params['rgb_colors']
quat_0 = torch.nn.functional.normalize(params['unnorm_rotations'])
opa_0 = torch.sigmoid(params['logit_opacities'])
scales_0 = torch.exp(params['log_scales'])
rendervar_init = {
'means3D': xyz_0,
'colors_precomp': rgb_0,
'rotations': quat_0,
'opacities': opa_0,
'scales': scales_0,
'means2D': torch.zeros_like(xyz_0),
} # before preprocess
w = self.width
h = self.height
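        # place the camera on a sphere around `center`: `distance` in meters,
        # `elevation`/`azimuth` in degrees, then build w2c by inverting the
        # look-at camera-to-world pose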
center = (0, 0, 0.1)
distance = 0.7
elevation = 20
azimuth = 180.0
target = np.array(center)
theta = 90 + azimuth
z = distance * math.sin(math.radians(elevation))
y = math.cos(math.radians(theta)) * distance * math.cos(math.radians(elevation))
x = math.sin(math.radians(theta)) * distance * math.cos(math.radians(elevation))
origin = target + np.array([x, y, z])
look_at = target - origin
look_at /= np.linalg.norm(look_at)
up = np.array([0.0, 0.0, 1.0])
right = np.cross(look_at, up)
right /= np.linalg.norm(right)
up = np.cross(right, look_at)
w2c = np.eye(4)
w2c[:3, 0] = right
w2c[:3, 1] = -up
w2c[:3, 2] = look_at
w2c[:3, 3] = origin
w2c = np.linalg.inv(w2c)
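        # pinhole intrinsics with fx = fy = w / 2 (a 90-degree horizontal FOV)
        # and the principal point at the image center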
k = np.array(
[[w / 2 * 1.0, 0., w / 2],
[0., w / 2 * 1.0, h / 2],
[0., 0., 1.]],
)
self.update_camera(k, w2c, w, h)
n_particles = self.cfg.sim.n_particles
downsample_indices = fps(xyz_0, torch.ones_like(xyz_0[:, 0]).to(torch.bool), n_particles, self.torch_device)
p_x_viewer = xyz_0[downsample_indices]
p_x = self.preprocess_x(p_x_viewer)
self.state['x'] = p_x
self.state['v'] = torch.zeros_like(p_x)
self.state['x_his'] = p_x[None].repeat(self.cfg.sim.n_history, 1, 1)
self.state['v_his'] = torch.zeros_like(p_x[None].repeat(self.cfg.sim.n_history, 1, 1))
self.state['x_pred'] = p_x
self.state['v_pred'] = torch.zeros_like(p_x)
rendervar_init, rendervar_init_full = self.update_rendervar(rendervar_init)
im, depth = self.render(rendervar_init_full, 0, bg=[0.0, 0.0, 0.0])
im_vis = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8)
return rendervar_init
    def _init_gpu_devices(self):
        # warp and the torch/warp device handles are (re)initialized inside
        # each @spaces.GPU call, since those run in the GPU worker process
        wp.init()
        gpus = [int(gpu) for gpu in self.cfg.gpus]
        wp_devices = [wp.get_device(f'cuda:{gpu}') for gpu in gpus]
        torch_devices = [torch.device(f'cuda:{gpu}') for gpu in gpus]
        assert len(torch_devices) == 1  # single-GPU only
        self.wp_device = wp_devices[0]
        self.torch_device = torch_devices[0]

    @spaces.GPU
    def reset(self):
        self._init_gpu_devices()
in_dir = root / 'log/gs/ckpts/rope_scene_1'
batch_size = 1
num_steps = 1
num_particles = self.cfg.sim.n_particles
self.load_scaniverse(in_dir)
self.init_model(batch_size, num_steps, num_particles, ckpt_path=None)
params = self.preprocess_gs(self.params)
if self.with_bg:
self.preprocess_bg_gs()
rendervar = self.reset_state(params, visualize_image=False, init=True)
rendervar, rendervar_full = self.update_rendervar(rendervar)
self.rendervar = rendervar
im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()
cv2.imwrite(str(root / 'log/temp_init/0000.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR))
        make_video(root / 'log/temp_init', root / 'log/gs/temp/form_video_init.mp4', '%04d.png', 1)
gs_pred = save_to_splat(
rendervar_full['means3D'].cpu().numpy(),
rendervar_full['colors_precomp'].cpu().numpy(),
rendervar_full['scales'].cpu().numpy(),
rendervar_full['rotations'].cpu().numpy(),
rendervar_full['opacities'].cpu().numpy(),
root / 'log/gs/temp/gs_pred.splat',
rot_rev=True,
)
form_video = gr.Video(
label='Predicted video',
            value=root / 'log/gs/temp/form_video.mp4',
format='mp4',
width=self.width,
height=self.height,
)
form_3dgs_pred = gr.Model3D(
label='Predicted Gaussian Splats',
height=self.height,
value=root / 'log/gs/temp/gs_pred.splat',
clear_color=[0, 0, 0, 0],
)
return form_video, form_3dgs_pred
def run_command(self, unit_command):
os.system('rm -rf ' + str(root / 'log/temp/*'))
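        # each button press integrates the unit command (cm/s) for 15 substeps of
        # 100 ms each, stepping the simulator and rendering one frame per substep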
# im_list = []
for i in range(15):
dt = 0.1 # 100ms
command = torch.tensor([unit_command]).to(self.device).to(torch.float32) # 5cm/s
command = self.preprocess_gripper(command)
# command_timestamp = torch.tensor([self.state['prev_key_pos_timestamp'] + (i+1) * dt]).to(self.device).to(torch.float32)
# print(command_timestamp)
if self.verbose:
print('command:', command.cpu().numpy().tolist())
            # step() consumes sub_pos and resets it to None, so each iteration
            # starts a fresh sub-trajectory from the previous key position
            assert self.state['sub_pos'] is None
            eef_xyz_latest = self.state['prev_key_pos']  # (1, 3), model frame
            eef_xyz_updated = eef_xyz_latest + command * dt * 0.01  # cm to m
            self.state['sub_pos'] = eef_xyz_updated[None]
# if self.state['sub_pos'] is None:
# eef_xyz = self.state['prev_key_pos']
# else:
# eef_xyz = self.state['sub_pos'][-1] # (1, 3), model frame
# if self.verbose:
# print(eef_xyz.cpu().numpy().tolist(), end=' ')
self.step()
rendervar, rendervar_full = self.update_rendervar(self.rendervar)
self.rendervar = rendervar
im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()
# im_list.append(im_show)
cv2.imwrite(str(root / f'log/temp/{i:04}.png'), cv2.cvtColor(im_show, cv2.COLOR_RGB2BGR))
# self.state['prev_key_pos_timestamp'] = self.state['prev_key_pos_timestamp'] + 20 * dt
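        # freeze the object at the final predicted state so the next command
        # starts from rest rather than carrying over velocity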
self.state['v'] *= 0.0
self.state['x'] = self.state['x_pred'].clone()
self.state['x_his'] = self.state['x'][None].repeat(self.cfg.sim.n_history, 1, 1)
self.state['v_his'] *= 0.0
self.state['v_pred'] *= 0.0
        make_video(root / 'log/temp', root / 'log/gs/temp/form_video.mp4', '%04d.png', 5)
form_video = gr.Video(
label='Predicted video',
            value=root / 'log/gs/temp/form_video.mp4',
format='mp4',
width=self.width,
height=self.height,
)
im, depth = self.render(rendervar_full, 0, bg=[0.0, 0.0, 0.0])
im_show = (im.permute(1, 2, 0) * 255.0).cpu().numpy().astype(np.uint8).copy()
gs_pred = save_to_splat(
rendervar_full['means3D'].cpu().numpy(),
rendervar_full['colors_precomp'].cpu().numpy(),
rendervar_full['scales'].cpu().numpy(),
rendervar_full['rotations'].cpu().numpy(),
rendervar_full['opacities'].cpu().numpy(),
root / 'log/gs/temp/gs_pred.splat',
rot_rev=True,
)
form_3dgs_pred = gr.Model3D(
label='Predicted Gaussian Splats',
height=self.height,
value=root / 'log/gs/temp/gs_pred.splat',
clear_color=[0, 0, 0, 0],
)
return form_video, form_3dgs_pred
    @spaces.GPU
    def on_click_run_xplus(self):
        self._init_gpu_devices()
        return self.run_command([5.0, 0, 0])

    @spaces.GPU
    def on_click_run_xminus(self):
        self._init_gpu_devices()
        return self.run_command([-5.0, 0, 0])

    @spaces.GPU
    def on_click_run_yplus(self):
        self._init_gpu_devices()
        return self.run_command([0, 5.0, 0])

    @spaces.GPU
    def on_click_run_yminus(self):
        self._init_gpu_devices()
        return self.run_command([0, -5.0, 0])

    @spaces.GPU
    def on_click_run_zplus(self):
        self._init_gpu_devices()
        return self.run_command([0, 0, 5.0])

    @spaces.GPU
    def on_click_run_zminus(self):
        self._init_gpu_devices()
        return self.run_command([0, 0, -5.0])
def launch(self, share=False):
with gr.Blocks() as app:
with gr.Row():
gr.Markdown("# Particle-Grid Neural Dynamics for Learning Deformable Object Models from RGB-D Videos")
with gr.Row():
gr.Markdown('### Project page: [https://kywind.github.io/pgnd](https://kywind.github.io/pgnd)')
with gr.Row():
# with gr.Column(scale=2):
# form_3dgs_orig = gr.Model3D(
# label='Original Gaussian Splats',
# value=None,
# )
with gr.Column(scale=2):
form_video = gr.Video(
label='Predicted video',
value=None,
format='mp4',
width=self.width,
height=self.height,
)
with gr.Column(scale=2):
form_3dgs_pred = gr.Model3D(
label='Predicted Gaussians',
height=self.height,
value=None,
clear_color=[0, 0, 0, 0],
)
# Layout
with gr.Row():
with gr.Column(scale=2):
with gr.Row():
run_reset = gr.Button("Reset")
with gr.Row():
with gr.Column():
run_xminus = gr.Button("x-")
with gr.Column():
run_xplus = gr.Button("x+")
with gr.Row():
with gr.Column():
run_yminus = gr.Button("y-")
with gr.Column():
run_yplus = gr.Button("y+")
with gr.Row():
with gr.Column():
run_zminus = gr.Button("z-")
with gr.Column():
run_zplus = gr.Button("z+")
with gr.Column(scale=2):
_ = gr.Button(visible=False) # empty placeholder
# Set up callbacks
run_reset.click(self.reset,
inputs=[],
outputs=[form_video, form_3dgs_pred])
run_xplus.click(self.on_click_run_xplus,
inputs=[],
outputs=[form_video, form_3dgs_pred])
run_xminus.click(self.on_click_run_xminus,
inputs=[],
outputs=[form_video, form_3dgs_pred])
run_yplus.click(self.on_click_run_yplus,
inputs=[],
outputs=[form_video, form_3dgs_pred])
run_yminus.click(self.on_click_run_yminus,
inputs=[],
outputs=[form_video, form_3dgs_pred])
run_zplus.click(self.on_click_run_zplus,
inputs=[],
outputs=[form_video, form_3dgs_pred])
run_zminus.click(self.on_click_run_zminus,
inputs=[],
outputs=[form_video, form_3dgs_pred])
app.launch(share=share)
if __name__ == '__main__':
visualizer = DynamicsVisualizer()
visualizer.launch(share=True)