Spaces:
Runtime error
Runtime error
import torch | |
from gsplat import rasterization | |
from dust3r.utils.geometry import inv, geotrf | |
def render(
    intrinsics: torch.Tensor,
    pts3d: torch.Tensor,
    rgbs: torch.Tensor | None = None,
    scale: float = 0.002,
    opacity: float = 0.95,
):
    """Rasterize point maps as Gaussian splats, one batch item at a time.

    Every 3D point becomes one Gaussian with the same size (``scale`` on all
    three axes) and the same ``opacity``. Each batch item is rendered with an
    identity view matrix, so the points are used as given (presumably already
    expressed in the target camera frame -- confirm against callers).

    Args:
        intrinsics: Per-item camera intrinsics; ``len(intrinsics)`` defines the
            batch size and ``intrinsics[[i]]`` is passed to the rasterizer
            (assumed to be ``(B, 3, 3)`` -- TODO confirm).
        pts3d: Point maps whose dims 1 and 2 are used as image height and
            width; reshaped internally to ``(B, -1, 3)``.
        rgbs: Optional per-point colors with 3 channels on dim 1 (e.g.
            ``(B, 3, H, W)``). When omitted, all points are colored with ones.
        scale: Size applied to every axis of every Gaussian.
        opacity: Opacity applied to every Gaussian.

    Returns:
        A tuple ``(rendered_rgbs, rendered_depths, accs)`` of tensors
        concatenated over the batch along dim 0: the first three channels of
        the rasterizer output, its fourth channel (depth), and the second
        output of the rasterizer (accumulated alpha per the gsplat docs).
    """
    device = pts3d.device
    batch_size = len(intrinsics)
    img_size = pts3d.shape[1:3]  # (height, width) of the target image
    pts3d = pts3d.reshape(batch_size, -1, 3)
    num_pts = pts3d.shape[1]

    # Random unit quaternions; since all three scales are identical the
    # orientation should not change the shape of the splats.
    quats = torch.randn((num_pts, 4), device=device)
    quats = quats / quats.norm(dim=-1, keepdim=True)
    scales = scale * torch.ones((num_pts, 3), device=device)
    opacities = opacity * torch.ones((num_pts), device=device)

    if rgbs is not None:
        assert rgbs.shape[1] == 3, "rgbs must have 3 channels on dim 1"
        # (B, 3, ...) -> (B, N, 3) so colors line up with the flattened points.
        rgbs = rgbs.reshape(batch_size, 3, -1).transpose(1, 2)
    else:
        rgbs = torch.ones_like(pts3d[:, :, :3])

    # Loop-invariant: a single identity view matrix shared by all batch items.
    viewmats = torch.eye(4, device=device)[None]

    rendered_rgbs = []
    rendered_depths = []
    accs = []
    for i in range(batch_size):
        rgbd, acc, _ = rasterization(
            pts3d[i],
            quats,
            scales,
            opacities,
            rgbs[i],
            viewmats,
            intrinsics[[i]],
            width=img_size[1],
            height=img_size[0],
            packed=False,
            render_mode="RGB+D",
        )
        # "RGB+D" packs color in the first three channels and depth in the
        # fourth; keep all outputs instead of silently dropping color/alpha.
        rendered_rgbs.append(rgbd[..., :3])
        rendered_depths.append(rgbd[..., 3])
        accs.append(acc)

    rendered_rgbs = torch.cat(rendered_rgbs, dim=0)
    rendered_depths = torch.cat(rendered_depths, dim=0)
    accs = torch.cat(accs, dim=0)
    return rendered_rgbs, rendered_depths, accs
def get_render_results(gts, preds, self_view=False):
    """Render depth maps for predicted and ground-truth points of each view.

    For every ``(gt, pred)`` pair, the predicted points and the ground-truth
    points (transformed by the inverse of a camera pose via ``geotrf``) are
    passed through ``render`` with the ground-truth image as colors; only the
    depth output is kept.

    Args:
        gts: Sequence of dicts providing ``"camera_pose"``,
            ``"camera_intrinsics"``, ``"img"`` and ``"pts3d"``.
        preds: Sequence of dicts providing ``"pts3d_in_self_view"`` and, when
            ``self_view`` is false, ``"pts3d_in_other_view"``.
        self_view: If true, each view uses its own pose/intrinsics and its
            ``"pts3d_in_self_view"`` prediction; otherwise the pose/intrinsics
            of ``gts[0]`` and the ``"pts3d_in_other_view"`` prediction are used.

    Returns:
        ``(depths, gt_depths)``: two lists with one rendered depth per view.
    """
    device = preds[0]["pts3d_in_self_view"].device
    pred_key = "pts3d_in_self_view" if self_view else "pts3d_in_other_view"

    depths, gt_depths = [], []
    with torch.no_grad():
        for gt, pred in zip(gts, preds):
            # Reference view supplying the camera: the view itself in
            # self-view mode, otherwise the first ground-truth view.
            ref_view = gt if self_view else gts[0]
            camera = inv(ref_view["camera_pose"]).to(device)
            intrinsics = ref_view["camera_intrinsics"].to(device)
            pred_pts3d = pred[pred_key]

            gt_img = gt["img"].to(device)
            gt_pts3d = gt["pts3d"].to(device)

            # Prediction first, then ground truth (same call order as before).
            _rgb, pred_depth, _acc = render(intrinsics, pred_pts3d, gt_img)
            gt_pts3d_cam = geotrf(camera, gt_pts3d)
            _rgb, gt_depth, _acc = render(intrinsics, gt_pts3d_cam, gt_img)

            depths.append(pred_depth)
            gt_depths.append(gt_depth)
    return depths, gt_depths