Spaces:
Runtime error
Runtime error
File size: 2,391 Bytes
2df809d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import torch
from gsplat import rasterization
from dust3r.utils.geometry import inv, geotrf
def render(
    intrinsics: torch.Tensor,
    pts3d: torch.Tensor,
    rgbs: torch.Tensor | None = None,
    scale: float = 0.002,
    opacity: float = 0.95,
):
    """Splat a batch of per-pixel point maps as small isotropic Gaussians.

    Each batch element is rendered independently with gsplat's
    ``rasterization`` using an identity view matrix, so the points are
    expected to already be expressed in the target camera's frame.

    Args:
        intrinsics: Per-view camera intrinsics; view ``i`` is passed to the
            rasterizer as ``intrinsics[[i]]`` (presumably ``(B, 3, 3)`` —
            TODO confirm against callers).
        pts3d: Point maps. ``shape[1:3]`` is used as the output ``(H, W)``
            and the tensor is reshaped to ``(B, H*W, 3)``, so it is expected
            to be ``(B, H, W, 3)``.
        rgbs: Optional per-point colors, channels-first with
            ``shape[1] == 3`` (e.g. ``(B, 3, H, W)``). Defaults to all ones.
        scale: Scale assigned to every Gaussian, identical on all three axes.
        opacity: Opacity assigned to every Gaussian.

    Returns:
        ``(rendered_rgbs, rendered_depths, accs)``: the per-view outputs of
        ``rasterization`` concatenated along dim 0 — the color channels,
        the depth channel of the ``"RGB+D"`` render, and the accumulated
        alpha map respectively.
    """
    device = pts3d.device
    batch_size = len(intrinsics)
    img_size = pts3d.shape[1:3]
    pts3d = pts3d.reshape(batch_size, -1, 3)
    num_pts = pts3d.shape[1]

    # All three scales are equal, so the covariance is rotation-invariant and
    # the quaternion is irrelevant. Identity quaternions (gsplat uses wxyz)
    # keep the render deterministic and leave the global RNG untouched.
    quats = torch.zeros((num_pts, 4), device=device)
    quats[:, 0] = 1.0
    scales = torch.full((num_pts, 3), scale, device=device)
    opacities = torch.full((num_pts,), opacity, device=device)

    if rgbs is not None:
        assert rgbs.shape[1] == 3, "rgbs must be channels-first with 3 channels"
        # (B, 3, ...) -> (B, N, 3), matching the flattened point order.
        colors = rgbs.reshape(batch_size, 3, -1).transpose(1, 2)
    else:
        colors = torch.ones_like(pts3d)

    # Points are already in camera coordinates: identity world-to-camera.
    viewmat = torch.eye(4, device=device)[None]

    rendered_rgbs, rendered_depths, accs = [], [], []
    for i in range(batch_size):
        rgbd, acc, _ = rasterization(
            pts3d[i],
            quats,
            scales,
            opacities,
            colors[i],
            viewmat,
            intrinsics[[i]],  # list index keeps a leading camera dim of 1
            width=img_size[1],
            height=img_size[0],
            packed=False,
            # NOTE(review): in gsplat "D" is accumulated (alpha-weighted, not
            # normalized) depth; "RGB+ED" would give expected depth — confirm
            # which one downstream metrics intend.
            render_mode="RGB+D",
        )
        rendered_rgbs.append(rgbd[..., :3])
        rendered_depths.append(rgbd[..., 3])
        accs.append(acc)

    return (
        torch.cat(rendered_rgbs, dim=0),
        torch.cat(rendered_depths, dim=0),
        torch.cat(accs, dim=0),
    )
def get_render_results(gts, preds, self_view=False):
    """Render predicted and ground-truth point maps to depth maps, per view.

    For each ``(gt, pred)`` pair, both point clouds are splatted with
    ``render`` and only the depth output is kept. ``gt["img"]`` is passed as
    the point colors, but the rendered colors are discarded.

    Args:
        gts: Sequence of ground-truth view dicts; the keys read are
            ``"camera_pose"``, ``"camera_intrinsics"``, ``"img"`` and
            ``"pts3d"``.
        preds: Sequence of prediction dicts, paired with ``gts`` via ``zip``.
        self_view: If True, each view is rendered in its own camera from
            ``pred["pts3d_in_self_view"]``. Otherwise every view is rendered
            in the first view's camera from ``pred["pts3d_in_other_view"]``.

    Returns:
        ``(depths, gt_depths)``: two lists with one rendered depth tensor per
        view pair, for the prediction and the ground truth respectively.
        Both are empty if either input sequence is empty.
    """
    # Read the device from the same key that is actually rendered, so
    # predictions carrying only one of the two point maps still work.
    pred_key = "pts3d_in_self_view" if self_view else "pts3d_in_other_view"
    if not gts or not preds:
        return [], []
    device = preds[0][pred_key].device

    depths = []
    gt_depths = []
    with torch.no_grad():
        if not self_view:
            # Every view is rendered from the first view's camera: compute once.
            ref_camera = inv(gts[0]["camera_pose"]).to(device)
            ref_intrinsics = gts[0]["camera_intrinsics"].to(device)

        for gt, pred in zip(gts, preds):
            if self_view:
                camera = inv(gt["camera_pose"]).to(device)
                intrinsics = gt["camera_intrinsics"].to(device)
            else:
                camera, intrinsics = ref_camera, ref_intrinsics
            pred_pts3d = pred[pred_key]
            gt_img = gt["img"].to(device)
            gt_pts3d = gt["pts3d"].to(device)

            _, depth, _ = render(intrinsics, pred_pts3d, gt_img)
            # render() uses an identity view matrix. Presumably camera_pose is
            # camera-to-world, so its inverse brings the GT points into the
            # rendering camera's frame — TODO confirm convention.
            _, gt_depth, _ = render(intrinsics, geotrf(camera, gt_pts3d), gt_img)
            depths.append(depth)
            gt_depths.append(gt_depth)
    return depths, gt_depths
|