# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R
import PIL.Image
import numpy as np
from scipy.spatial.transform import Rotation
import torch
import cv2
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from dust3r.utils.geometry import (
geotrf,
get_med_dist_between_poses,
depthmap_to_absolute_camera_coordinates,
)
from dust3r.utils.device import to_numpy
from dust3r.utils.image import rgb, img_to_arr
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
try:
import trimesh
except ImportError:
print("/!\\ module trimesh is not installed, cannot visualize results /!\\")
def float2uint8(x):
return (255.0 * x).astype(np.uint8)
def uint82float(img):
return np.ascontiguousarray(img) / 255.0
def cat_3d(vecs):
if isinstance(vecs, (np.ndarray, torch.Tensor)):
vecs = [vecs]
return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)])
def show_raw_pointcloud(pts3d, colors, point_size=2):
scene = trimesh.Scene()
pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))
scene.add_geometry(pct)
scene.show(line_settings={"point_size": point_size})
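# Usage sketch (hypothetical arrays; requires trimesh): colors may be float
# RGB in [0, 1] or uint8, one per point.
#   pts = np.random.rand(1000, 3)
#   cols = np.random.rand(1000, 3)
#   show_raw_pointcloud(pts, cols, point_size=3)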
def pts3d_to_trimesh(img, pts3d, valid=None):
H, W, THREE = img.shape
assert THREE == 3
assert img.shape == pts3d.shape
vertices = pts3d.reshape(-1, 3)
idx = np.arange(len(vertices)).reshape(H, W)
idx1 = idx[:-1, :-1].ravel() # top-left corner
    idx2 = idx[:-1, +1:].ravel()  # top-right corner
idx3 = idx[+1:, :-1].ravel() # bottom-left corner
idx4 = idx[+1:, +1:].ravel() # bottom-right corner
faces = np.concatenate(
(
np.c_[idx1, idx2, idx3],
np.c_[
idx3, idx2, idx1
], # same triangle, but backward (cheap solution to cancel face culling)
np.c_[idx2, idx3, idx4],
np.c_[
idx4, idx3, idx2
], # same triangle, but backward (cheap solution to cancel face culling)
),
axis=0,
)
face_colors = np.concatenate(
(
img[:-1, :-1].reshape(-1, 3),
img[:-1, :-1].reshape(-1, 3),
img[+1:, +1:].reshape(-1, 3),
img[+1:, +1:].reshape(-1, 3),
),
axis=0,
)
if valid is not None:
assert valid.shape == (H, W)
valid_idxs = valid.ravel()
valid_faces = valid_idxs[faces].all(axis=-1)
faces = faces[valid_faces]
face_colors = face_colors[valid_faces]
assert len(faces) == len(face_colors)
return dict(vertices=vertices, face_colors=face_colors, faces=faces)
def cat_meshes(meshes):
vertices, faces, colors = zip(
*[(m["vertices"], m["faces"], m["face_colors"]) for m in meshes]
)
n_vertices = np.cumsum([0] + [len(v) for v in vertices])
for i in range(len(faces)):
faces[i][:] += n_vertices[i]
vertices = np.concatenate(vertices)
colors = np.concatenate(colors)
faces = np.concatenate(faces)
return dict(vertices=vertices, face_colors=colors, faces=faces)
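# Usage sketch (hypothetical per-view data): turn two RGB-D predictions into a
# single vertex-colored mesh. Assumes imgs[i] is (H, W, 3) in [0, 1], pts3d[i]
# is (H, W, 3), and masks[i] is a boolean (H, W) validity map.
#   meshes = [pts3d_to_trimesh(imgs[i], pts3d[i], masks[i]) for i in range(2)]
#   mesh = trimesh.Trimesh(**cat_meshes(meshes))
#   mesh.show()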
def show_duster_pairs(view1, view2, pred1, pred2):
import matplotlib.pyplot as pl
pl.ion()
for e in range(len(view1["instance"])):
i = view1["idx"][e]
j = view2["idx"][e]
img1 = rgb(view1["img"][e])
img2 = rgb(view2["img"][e])
conf1 = pred1["conf"][e].squeeze()
conf2 = pred2["conf"][e].squeeze()
score = conf1.mean() * conf2.mean()
print(f">> Showing pair #{e} {i}-{j} {score=:g}")
pl.clf()
pl.subplot(221).imshow(img1)
pl.subplot(223).imshow(img2)
pl.subplot(222).imshow(conf1, vmin=1, vmax=30)
pl.subplot(224).imshow(conf2, vmin=1, vmax=30)
pts1 = pred1["pts3d"][e]
pts2 = pred2["pts3d_in_other_view"][e]
pl.subplots_adjust(0, 0, 1, 1, 0, 0)
if input("show pointcloud? (y/n) ") == "y":
show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5)
def auto_cam_size(im_poses):
return 0.1 * get_med_dist_between_poses(im_poses)
class SceneViz:
def __init__(self):
self.scene = trimesh.Scene()
def add_pointcloud(self, pts3d, color=(0, 0, 0), mask=None, denoise=False):
pts3d = to_numpy(pts3d)
mask = to_numpy(mask)
if not isinstance(pts3d, list):
pts3d = [pts3d.reshape(-1, 3)]
if mask is not None:
mask = [mask.ravel()]
if not isinstance(color, (tuple, list)):
color = [color.reshape(-1, 3)]
if mask is None:
mask = [slice(None)] * len(pts3d)
pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
pct = trimesh.PointCloud(pts)
if isinstance(color, (list, np.ndarray, torch.Tensor)):
color = to_numpy(color)
col = np.concatenate([p[m] for p, m in zip(color, mask)])
            assert col.shape == pts.shape
pct.visual.vertex_colors = uint8(col.reshape(-1, 3))
else:
assert len(color) == 3
pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape)
if denoise:
centroid = np.median(pct.vertices, axis=0)
dist_to_centroid = np.linalg.norm(pct.vertices - centroid, axis=-1)
dist_thr = np.quantile(dist_to_centroid, 0.99)
valid = dist_to_centroid < dist_thr
            pct = trimesh.PointCloud(
                pct.vertices[valid], colors=pct.visual.vertex_colors[valid]
            )
self.scene.add_geometry(pct)
return self
def add_rgbd(
self, image, depth, intrinsics=None, cam2world=None, zfar=np.inf, mask=None
):
if intrinsics is None:
H, W, THREE = image.shape
focal = max(H, W)
intrinsics = np.float32([[focal, 0, W / 2], [0, focal, H / 2], [0, 0, 1]])
pts3d, mask2 = depthmap_to_absolute_camera_coordinates(
depth, intrinsics, cam2world
)
mask2 &= depth < zfar
if mask is not None:
mask2 &= mask
return self.add_pointcloud(pts3d, image, mask=mask2)
def add_camera(
self,
pose_c2w,
focal=None,
color=(0, 0, 0),
image=None,
imsize=None,
cam_size=0.03,
):
pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image))
image = img_to_arr(image)
if isinstance(focal, np.ndarray) and focal.shape == (3, 3):
intrinsics = focal
focal = (intrinsics[0, 0] * intrinsics[1, 1]) ** 0.5
if imsize is None:
imsize = (2 * intrinsics[0, 2], 2 * intrinsics[1, 2])
add_scene_cam(
self.scene,
pose_c2w,
color,
image,
focal,
imsize=imsize,
screen_width=cam_size,
marker=None,
)
return self
def add_cameras(
self, poses, focals=None, images=None, imsizes=None, colors=None, **kw
):
get = lambda arr, idx: None if arr is None else arr[idx]
for i, pose_c2w in enumerate(poses):
self.add_camera(
pose_c2w,
get(focals, i),
image=get(images, i),
color=get(colors, i),
imsize=get(imsizes, i),
**kw,
)
return self
def show(self, point_size=2):
self.scene.show(line_settings={"point_size": point_size})
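# Minimal usage sketch (hypothetical inputs): per-view (H, W, 3) images and
# point maps, (N, 4, 4) camera-to-world poses, (N,) focals, and confidences.
#   viz = SceneViz()
#   viz.add_pointcloud(pts3d, imgs, mask=[c > 1.5 for c in confs])
#   viz.add_cameras(cams2world, focals=focals, images=imgs,
#                   cam_size=auto_cam_size(cams2world))
#   viz.show(point_size=2)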
def show_raw_pointcloud_with_cams(
imgs, pts3d, mask, focals, cams2world, point_size=2, cam_size=0.05, cam_color=None
):
"""Visualization of a pointcloud with cameras
imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...]
pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...]
focals = (N,) or N-size list of [focal, ...]
cams2world = (N,4,4) or N-size list of [(4,4), ...]
"""
assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
pts3d = to_numpy(pts3d)
imgs = to_numpy(imgs)
focals = to_numpy(focals)
cams2world = to_numpy(cams2world)
scene = trimesh.Scene()
pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
scene.add_geometry(pct)
for i, pose_c2w in enumerate(cams2world):
if isinstance(cam_color, list):
camera_edge_color = cam_color[i]
else:
camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
add_scene_cam(
scene,
pose_c2w,
camera_edge_color,
imgs[i] if i < len(imgs) else None,
focals[i],
screen_width=cam_size,
)
scene.show(line_settings={"point_size": point_size})
def add_scene_cam(
scene,
pose_c2w,
edge_color,
image=None,
focal=None,
imsize=None,
screen_width=0.03,
marker=None,
):
if image is not None:
image = np.asarray(image)
H, W, THREE = image.shape
assert THREE == 3
if image.dtype != np.uint8:
image = np.uint8(255 * image)
elif imsize is not None:
W, H = imsize
elif focal is not None:
H = W = focal / 1.1
else:
H = W = 1
if isinstance(focal, np.ndarray):
focal = focal[0]
if not focal:
focal = min(H, W) * 1.1 # default value
height = max(screen_width / 10, focal * screen_width / H)
width = screen_width * 0.5**0.5
rot45 = np.eye(4)
rot45[:3, :3] = Rotation.from_euler("z", np.deg2rad(45)).as_matrix()
rot45[2, 3] = -height # set the tip of the cone = optical center
aspect_ratio = np.eye(4)
aspect_ratio[0, 0] = W / H
transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
cam = trimesh.creation.cone(width, height, sections=4) # , transform=transform)
if image is not None:
vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
img = trimesh.Trimesh(vertices=vertices, faces=faces)
uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
img.visual = trimesh.visual.TextureVisuals(
uv_coords, image=PIL.Image.fromarray(image)
)
scene.add_geometry(img)
rot2 = np.eye(4)
rot2[:3, :3] = Rotation.from_euler("z", np.deg2rad(2)).as_matrix()
vertices = np.r_[cam.vertices, 0.95 * cam.vertices, geotrf(rot2, cam.vertices)]
vertices = geotrf(transform, vertices)
faces = []
for face in cam.faces:
if 0 in face:
continue
a, b, c = face
a2, b2, c2 = face + len(cam.vertices)
a3, b3, c3 = face + 2 * len(cam.vertices)
faces.append((a, b, b2))
faces.append((a, a2, c))
faces.append((c2, b, c))
faces.append((a, b, b3))
faces.append((a, a3, c))
faces.append((c3, b, c))
faces += [(c, b, a) for a, b, c in faces]
cam = trimesh.Trimesh(vertices=vertices, faces=faces)
cam.visual.face_colors[:, :3] = edge_color
scene.add_geometry(cam)
if marker == "o":
marker = trimesh.creation.icosphere(3, radius=screen_width / 4)
marker.vertices += pose_c2w[:3, 3]
marker.visual.face_colors[:, :3] = edge_color
scene.add_geometry(marker)
def cat(a, b):
return np.concatenate((a.reshape(-1, 3), b.reshape(-1, 3)))
# OpenCV-to-OpenGL camera convention flip: keep x right, point y up and z backward
OPENGL = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
CAM_COLORS = [
(255, 0, 0),
(0, 0, 255),
(0, 255, 0),
(255, 0, 255),
(255, 204, 0),
(0, 204, 204),
(128, 255, 255),
(255, 128, 255),
(255, 255, 128),
(0, 0, 0),
(128, 128, 128),
]
def uint8(colors):
    if not isinstance(colors, np.ndarray):
        colors = np.array(colors)
    if np.issubdtype(colors.dtype, np.floating):
        colors = colors * 255  # avoid mutating the caller's array in place
    assert 0 <= colors.min() and colors.max() < 256
    return np.uint8(colors)
def segment_sky(image):
import cv2
from scipy import ndimage
image = to_numpy(image)
if np.issubdtype(image.dtype, np.floating):
image = np.uint8(255 * image.clip(min=0, max=1))
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
lower_blue = np.array([0, 0, 100])
upper_blue = np.array([30, 255, 255])
mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)
mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)
mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)
mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)
kernel = np.ones((5, 5), np.uint8)
mask2 = ndimage.binary_opening(mask, structure=kernel)
_, labels, stats, _ = cv2.connectedComponentsWithStats(
mask2.view(np.uint8), connectivity=8
)
cc_sizes = stats[1:, cv2.CC_STAT_AREA]
order = cc_sizes.argsort()[::-1] # bigger first
i = 0
selection = []
while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:
selection.append(1 + order[i])
i += 1
    mask3 = np.isin(labels, selection)
return torch.from_numpy(mask3)
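# Usage sketch (hypothetical arrays): drop sky pixels before building a point
# cloud, since depth on sky regions is unreliable.
#   sky = segment_sky(img)            # bool tensor, True on sky pixels
#   valid = conf_mask & ~sky.numpy()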
def get_vertical_colorbar(h, vmin, vmax, cmap_name="jet", label=None, cbar_precision=2):
"""
:param w: pixels
:param h: pixels
:param vmin: min value
:param vmax: max value
:param cmap_name:
:param label
:return:
"""
fig = Figure(figsize=(2, 8), dpi=100)
fig.subplots_adjust(right=1.5)
canvas = FigureCanvasAgg(fig)
ax = fig.add_subplot(111)
cmap = cm.get_cmap(cmap_name)
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
tick_cnt = 6
tick_loc = np.linspace(vmin, vmax, tick_cnt)
cb1 = mpl.colorbar.ColorbarBase(
ax, cmap=cmap, norm=norm, ticks=tick_loc, orientation="vertical"
)
tick_label = [str(np.round(x, cbar_precision)) for x in tick_loc]
if cbar_precision == 0:
tick_label = [x[:-2] for x in tick_label]
cb1.set_ticklabels(tick_label)
cb1.ax.tick_params(labelsize=18, rotation=0)
if label is not None:
cb1.set_label(label)
fig.tight_layout()
canvas.draw()
s, (width, height) = canvas.print_to_buffer()
im = np.frombuffer(s, np.uint8).reshape((height, width, 4))
im = im[:, :, :3].astype(np.float32) / 255.0
if h != im.shape[0]:
w = int(im.shape[1] / im.shape[0] * h)
im = cv2.resize(im, (w, h), interpolation=cv2.INTER_AREA)
return im
def colorize_np(
x,
cmap_name="jet",
mask=None,
range=None,
append_cbar=False,
cbar_in_image=False,
cbar_precision=2,
):
"""
turn a grayscale image into a color image
:param x: input grayscale, [H, W]
:param cmap_name: the colorization method
:param mask: the mask image, [H, W]
:param range: the range for scaling, automatic if None, [min, max]
:param append_cbar: if append the color bar
:param cbar_in_image: put the color bar inside the image to keep the output image the same size as the input image
:return: colorized image, [H, W]
"""
if range is not None:
vmin, vmax = range
elif mask is not None:
vmin = np.min(x[mask][np.nonzero(x[mask])])
vmax = np.max(x[mask])
x[np.logical_not(mask)] = vmin
else:
vmin, vmax = np.percentile(x, (1, 100))
vmax += 1e-6
x = np.clip(x, vmin, vmax)
x = (x - vmin) / (vmax - vmin)
cmap = cm.get_cmap(cmap_name)
x_new = cmap(x)[:, :, :3]
if mask is not None:
mask = np.float32(mask[:, :, np.newaxis])
x_new = x_new * mask + np.ones_like(x_new) * (1.0 - mask)
    if append_cbar:
        cbar = get_vertical_colorbar(
            h=x.shape[0],
            vmin=vmin,
            vmax=vmax,
            cmap_name=cmap_name,
            cbar_precision=cbar_precision,
        )
        if cbar_in_image:
            x_new[:, -cbar.shape[1]:, :] = cbar
        else:
            x_new = np.concatenate(
                (x_new, np.zeros_like(x_new[:, :5, :]), cbar), axis=1
            )
    return x_new
def colorize(
x, cmap_name="jet", mask=None, range=None, append_cbar=False, cbar_in_image=False
):
"""
turn a grayscale image into a color image
:param x: torch.Tensor, grayscale image, [H, W] or [B, H, W]
:param mask: torch.Tensor or None, mask image, [H, W] or [B, H, W] or None
"""
device = x.device
x = x.cpu().numpy()
if mask is not None:
mask = mask.cpu().numpy() > 0.99
kernel = np.ones((3, 3), np.uint8)
if x.ndim == 2:
x = x[None]
if mask is not None:
mask = mask[None]
    out = []
    for i, x_ in enumerate(x):
        mask_ = None
        if mask is not None:
            # erode per-sample so the 2D mask matches the 2D image slice
            mask_ = cv2.erode(mask[i].astype(np.uint8), kernel, iterations=1).astype(bool)
        x_ = colorize_np(x_, cmap_name, mask_, range, append_cbar, cbar_in_image)
        out.append(torch.from_numpy(x_).to(device).float())
out = torch.stack(out).squeeze(0)
return out
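# Usage sketch (hypothetical depth tensor): color-map a depth map and append
# a colorbar on the right.
#   depth_vis = colorize(depth, cmap_name="jet", range=(0.1, 10.0), append_cbar=True)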
def draw_correspondences(
imgs1, imgs2, coords1, coords2, interval=10, color_by=0, radius=2
):
"""
draw correspondences between two images
:param img1: tensor [B, H, W, 3]
:param img2: tensor [B, H, W, 3]
:param coord1: tensor [B, N, 2]
:param coord2: tensor [B, N, 2]
:param interval: int the interval between two points
:param color_by: specify the color based on image 1 or image 2, 0 or 1
:return: [B, 2*H, W, 3]
"""
batch_size = len(imgs1)
out = []
for i in range(batch_size):
img1 = imgs1[i].detach().cpu().numpy()
img2 = imgs2[i].detach().cpu().numpy()
coord1 = (
coords1[i].detach().cpu().numpy()[::interval, ::interval].reshape(-1, 2)
)
coord2 = (
coords2[i].detach().cpu().numpy()[::interval, ::interval].reshape(-1, 2)
)
img = drawMatches(
img1, img2, coord1, coord2, radius=radius, color_by=color_by, row_cat=True
)
out.append(img)
out = np.stack(out)
return out
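# Usage sketch (hypothetical tensors): imgs are [B, H, W, 3] in [0, 1] and
# coords are dense [B, H, W, 2] (x, y) maps; every 16th point is drawn.
#   vis = draw_correspondences(imgs1, imgs2, coords1, coords2, interval=16)
#   # vis: uint8 array of shape [B, 2*H, W, 3]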
def draw_correspondences_lines(
imgs1, imgs2, coords1, coords2, interval=10, color_by=0, radius=2
):
"""
draw correspondences between two images
:param img1: tensor [B, H, W, 3]
:param img2: tensor [B, H, W, 3]
:param coord1: tensor [B, N, 2]
:param coord2: tensor [B, N, 2]
:param interval: int the interval between two points
:param color_by: specify the color based on image 1 or image 2, 0 or 1
:return: [B, 2*H, W, 3]
"""
batch_size = len(imgs1)
out = []
for i in range(batch_size):
img1 = imgs1[i].detach().cpu().numpy()
img2 = imgs2[i].detach().cpu().numpy()
coord1 = (
coords1[i].detach().cpu().numpy()[::interval, ::interval].reshape(-1, 2)
)
coord2 = (
coords2[i].detach().cpu().numpy()[::interval, ::interval].reshape(-1, 2)
)
img = drawMatches_lines(
img1, img2, coord1, coord2, radius=radius, color_by=color_by, row_cat=True
)
out.append(img)
out = np.stack(out)
return out
def drawMatches(img1, img2, kp1, kp2, radius=2, mask=None, color_by=0, row_cat=False):
h1, w1 = img1.shape[:2]
h2, w2 = img2.shape[:2]
img1 = np.ascontiguousarray(float2uint8(img1))
img2 = np.ascontiguousarray(float2uint8(img2))
center1 = np.median(kp1, axis=0)
center2 = np.median(kp2, axis=0)
    # fixed palette of 128 hue-coded BGR colors
    colors = {
        m: (255 * np.array(plt.cm.hsv(m / 128.0))[:3][::-1]).astype(np.int32)
        for m in range(128)
    }
if mask is not None:
ind = np.argsort(mask)[::-1]
kp1 = kp1[ind]
kp2 = kp2[ind]
mask = mask[ind]
for i, (pt1, pt2) in enumerate(zip(kp1, kp2)):
if color_by == 0:
coord_angle = np.arctan2(pt1[1] - center1[1], pt1[0] - center1[0])
elif color_by == 1:
coord_angle = np.arctan2(pt2[1] - center2[1], pt2[0] - center2[0])
corr_color = np.int32(64 * coord_angle / np.pi) % 128
color = tuple(colors[corr_color].tolist())
if (
(pt1[0] <= w1 - 1)
and (pt1[0] >= 0)
and (pt1[1] <= h1 - 1)
and (pt1[1] >= 0)
):
img1 = cv2.circle(
img1, (int(pt1[0]), int(pt1[1])), radius, color, -1, cv2.LINE_AA
)
if (
(pt2[0] <= w2 - 1)
and (pt2[0] >= 0)
and (pt2[1] <= h2 - 1)
and (pt2[1] >= 0)
):
if mask is not None and mask[i]:
img2 = cv2.drawMarker(
img2,
(int(pt2[0]), int(pt2[1])),
color,
markerType=cv2.MARKER_CROSS,
markerSize=int(5 * radius),
thickness=int(radius / 2),
line_type=cv2.LINE_AA,
)
else:
img2 = cv2.circle(
img2, (int(pt2[0]), int(pt2[1])), radius, color, -1, cv2.LINE_AA
)
    if row_cat:
        return np.concatenate([img1, img2], axis=0)
    return np.concatenate([img1, img2], axis=1)
def drawMatches_lines(
img1, img2, kp1, kp2, radius=2, mask=None, color_by=0, row_cat=False
):
h1, w1 = img1.shape[:2]
h2, w2 = img2.shape[:2]
img1 = np.ascontiguousarray(float2uint8(img1))
img2 = np.ascontiguousarray(float2uint8(img2))
center1 = np.median(kp1, axis=0)
center2 = np.median(kp2, axis=0)
    # fixed palette of 128 hue-coded BGR colors
    colors = {
        m: (255 * np.array(plt.cm.hsv(m / 128.0))[:3][::-1]).astype(np.int32)
        for m in range(128)
    }
if mask is not None:
ind = np.argsort(mask)[::-1]
kp1 = kp1[ind]
kp2 = kp2[ind]
mask = mask[ind]
if row_cat:
whole_img = np.concatenate([img1, img2], axis=0)
else:
whole_img = np.concatenate([img1, img2], axis=1)
for i, (pt1, pt2) in enumerate(zip(kp1, kp2)):
if color_by == 0:
coord_angle = np.arctan2(pt1[1] - center1[1], pt1[0] - center1[0])
elif color_by == 1:
coord_angle = np.arctan2(pt2[1] - center2[1], pt2[0] - center2[0])
corr_color = np.int32(64 * coord_angle / np.pi) % 128
color = tuple(colors[corr_color].tolist())
rand_val = np.random.rand()
if rand_val < 0.1:
if (
(pt1[0] <= w1 - 1)
and (pt1[0] >= 0)
and (pt1[1] <= h1 - 1)
and (pt1[1] >= 0)
) and (
(pt2[0] <= w2 - 1)
and (pt2[0] >= 0)
and (pt2[1] <= h2 - 1)
and (pt2[1] >= 0)
):
whole_img = cv2.circle(
whole_img,
(int(pt1[0]), int(pt1[1])),
radius,
color,
-1,
cv2.LINE_AA,
)
if row_cat:
whole_img = cv2.circle(
whole_img,
(int(pt2[0]), int(pt2[1] + h1)),
radius,
color,
-1,
cv2.LINE_AA,
)
cv2.line(
whole_img,
(int(pt1[0]), int(pt1[1])),
(int(pt2[0]), int(pt2[1] + h1)),
color,
1,
cv2.LINE_AA,
)
else:
whole_img = cv2.circle(
whole_img,
(int(pt2[0] + w1), int(pt2[1])),
radius,
color,
-1,
cv2.LINE_AA,
)
cv2.line(
whole_img,
(int(pt1[0]), int(pt1[1])),
(int(pt2[0] + w1), int(pt2[1])),
color,
1,
cv2.LINE_AA,
)
    return whole_img
import os
import time
import viser
def rotation_matrix_to_quaternion(R):
"""
:param R: [3, 3]
:return: [4]
"""
tr = np.trace(R)
Rxx = R[0, 0]
Ryy = R[1, 1]
Rzz = R[2, 2]
q = np.zeros(4)
q[0] = 0.5 * np.sqrt(1 + tr)
q[1] = (R[2, 1] - R[1, 2]) / (4 * q[0])
q[2] = (R[0, 2] - R[2, 0]) / (4 * q[0])
q[3] = (R[1, 0] - R[0, 1]) / (4 * q[0])
return q
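# Sanity check (scipy's Rotation is imported at the top of this file): scipy's
# as_quat returns (x, y, z, w), whereas this helper returns (w, x, y, z).
#   R = Rotation.from_euler("y", 0.3).as_matrix()
#   q_wxyz = rotation_matrix_to_quaternion(R)
#   assert np.allclose(q_wxyz, np.roll(Rotation.from_matrix(R).as_quat(), 1))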
class PointCloudViewer:
def __init__(self, pc_dir, device="cpu"):
self.server = viser.ViserServer()
self.server.set_up_direction("-y")
self.device = device
self.tt = lambda x: torch.from_numpy(x).float().to(device)
self.pc_dir = pc_dir
self.pcs, self.all_steps = self.read_data()
self.num_frames = len(self.all_steps)
self.fix_camera = False
self.camera_scale = self.server.add_gui_slider(
"camera_scale",
min=0.01,
max=1.0,
step=0.01,
initial_value=0.1,
)
self.camera_handles = []
def read_data(self):
pc_list = os.listdir(self.pc_dir)
pc_list.sort(key=lambda x: int(x.split(".")[0].split("_")[-1]))
pcs = {}
step_list = []
for pc_name in pc_list:
pc = np.load(os.path.join(self.pc_dir, pc_name))
step = int(pc_name.split(".")[0].split("_")[-1])
pcs.update({step: {"pc": pc}})
step_list.append(step)
return pcs, step_list
def parse_pc_data(self, pc, batch_idx=-1):
idx = batch_idx
ret_dict = {}
for i in range(len(pc.keys()) // 2):
pred_pts = pc[f"pts3d_{i+1}"][idx].reshape(-1, 3) # [N, 3]
color = pc[f"colors_{i+1}"][idx].reshape(-1, 3) # [N, 3]
ret_dict.update({f"pred_pts_{i+1}": pred_pts, f"color_{i+1}": color})
return ret_dict
def add_pc(self, step):
pc = self.pcs[step]["pc"]
pc_dict = self.parse_pc_data(pc)
for i in range(len(pc_dict.keys()) // 2):
self.server.add_point_cloud(
name=f"/frames/{step}/pred_pts_{i+1}_{step}",
points=pc_dict[f"pred_pts_{i+1}"],
colors=pc_dict[f"color_{i+1}"],
point_size=0.002,
)
        if not self.fix_camera:
            raise NotImplementedError
        # NOTE: the alignment code below is currently unreachable (fix_camera
        # is always False above) and depends on find_rigid_alignment_batched,
        # which is not defined in this module.
R21, T21 = find_rigid_alignment_batched(
torch.from_numpy(pc_dict["pred_pts1_2"][None]),
torch.from_numpy(pc_dict["pred_pts1_1"][None]),
)
R12, T12 = find_rigid_alignment_batched(
torch.from_numpy(pc_dict["pred_pts2_1"][None]),
torch.from_numpy(pc_dict["pred_pts2_2"][None]),
)
R21 = R21[0].numpy()
T21 = T21.numpy()
R12 = R12[0].numpy()
T12 = T12.numpy()
pred_pts1_2 = pc_dict["pred_pts1_2"] @ R21.T + T21
pred_pts2_1 = pc_dict["pred_pts2_1"] @ R12.T + T12
self.server.add_point_cloud(
name=f"/frames/{step}/pred_pts1_2_{step}",
points=pred_pts1_2,
colors=pc_dict["color1_2"],
point_size=0.002,
)
self.server.add_point_cloud(
name=f"/frames/{step}/pred_pts2_1_{step}",
points=pred_pts2_1,
colors=pc_dict["color2_1"],
point_size=0.002,
)
img1 = pc_dict["color1_1"].reshape(224, 224, 3)
img2 = pc_dict["color2_2"].reshape(224, 224, 3)
self.camera_handles.append(
self.server.add_camera_frustum(
name=f"/frames/{step}/camera1_{step}",
fov=2.0 * np.arctan(224.0 / 490.0),
aspect=1.0,
scale=self.camera_scale.value,
color=(1.0, 0, 0),
image=img1,
)
)
self.camera_handles.append(
self.server.add_camera_frustum(
name=f"/frames/{step}/camera2_{step}",
fov=2.0 * np.arctan(224.0 / 490.0),
aspect=1.0,
scale=self.camera_scale.value,
color=(0, 0, 1.0),
wxyz=rotation_matrix_to_quaternion(R21),
position=T21,
image=img2,
)
)
def animate(self):
with self.server.add_gui_folder("Playback"):
gui_timestep = self.server.add_gui_slider(
"Train Step",
min=0,
max=self.num_frames - 1,
step=1,
initial_value=0,
disabled=True,
)
gui_next_frame = self.server.add_gui_button("Next Step", disabled=True)
gui_prev_frame = self.server.add_gui_button("Prev Step", disabled=True)
gui_playing = self.server.add_gui_checkbox("Playing", False)
gui_framerate = self.server.add_gui_slider(
"FPS", min=1, max=60, step=0.1, initial_value=1
)
gui_framerate_options = self.server.add_gui_button_group(
"FPS options", ("10", "20", "30", "60")
)
@gui_next_frame.on_click
def _(_) -> None:
gui_timestep.value = (gui_timestep.value + 1) % self.num_frames
@gui_prev_frame.on_click
def _(_) -> None:
gui_timestep.value = (gui_timestep.value - 1) % self.num_frames
@gui_playing.on_update
def _(_) -> None:
gui_timestep.disabled = gui_playing.value
gui_next_frame.disabled = gui_playing.value
gui_prev_frame.disabled = gui_playing.value
@gui_framerate_options.on_click
def _(_) -> None:
gui_framerate.value = int(gui_framerate_options.value)
prev_timestep = gui_timestep.value
@gui_timestep.on_update
def _(_) -> None:
nonlocal prev_timestep
current_timestep = gui_timestep.value
with self.server.atomic():
frame_nodes[current_timestep].visible = True
frame_nodes[prev_timestep].visible = False
prev_timestep = current_timestep
self.server.flush() # Optional!
self.server.add_frame(
"/frames",
show_axes=False,
)
frame_nodes = []
for i in range(self.num_frames):
step = self.all_steps[i]
frame_nodes.append(
self.server.add_frame(
f"/frames/{step}",
show_axes=False,
)
)
self.add_pc(step)
for i, frame_node in enumerate(frame_nodes):
frame_node.visible = i == gui_timestep.value
prev_timestep = gui_timestep.value
while True:
if gui_playing.value:
gui_timestep.value = (gui_timestep.value + 1) % self.num_frames
for handle in self.camera_handles:
handle.scale = self.camera_scale.value
time.sleep(1.0 / gui_framerate.value)
def run(self):
self.animate()
while True:
time.sleep(10.0)
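# Usage sketch (hypothetical directory layout): point clouds saved as
# np.load-able archives named like "<prefix>_<step>.npz" with keys
# "pts3d_1", "colors_1", "pts3d_2", "colors_2", ...
#   viewer = PointCloudViewer("logs/pointclouds")
#   viewer.run()  # blocks and serves an interactive viser session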
from sklearn.decomposition import PCA
def colorize_feature_map(x):
"""
Args:
x: torch.Tensor, [B, H, W, D]
Returns:
torch.Tensor, [B, H, W, 3]
"""
device = x.device
x = x.cpu().numpy()
out = []
for x_ in x:
x_ = colorize_feature_map_np(x_)
out.append(torch.from_numpy(x_).to(device))
out = torch.stack(out).squeeze(0)
return out
def colorize_feature_map_np(x):
"""
Args:
x: np.ndarray, [H, W, D]
"""
pca = PCA(n_components=3)
pca_features = pca.fit_transform(x.reshape(-1, x.shape[-1]))
pca_features = (pca_features - pca_features.min()) / (
pca_features.max() - pca_features.min()
)
pca_features = pca_features.reshape(x.shape[0], x.shape[1], 3)
return pca_features
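# Usage sketch (hypothetical feature map): project D-dim features to RGB.
#   feats = torch.randn(2, 32, 32, 64)     # [B, H, W, D]
#   rgb_vis = colorize_feature_map(feats)  # [2, 32, 32, 3], values in [0, 1]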