Spaces:
Runtime error
Runtime error
File size: 5,032 Bytes
19c4ddf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from typing import Iterator, Optional, Tuple
import numpy as np
import torch
from shap_e.rendering.view_data import ProjectiveCamera
from ._utils import cross_product
from .types import RayCollisions, Rays, TriMesh
def cast_camera(
    camera: ProjectiveCamera,
    mesh: TriMesh,
    ray_batch_size: Optional[int] = None,
    checkpoint: Optional[bool] = None,
) -> Iterator[RayCollisions]:
    """
    Cast one ray per camera pixel onto a mesh, yielding results batch by batch.

    Pixels are enumerated in row-major order: flat index k maps to the image
    coordinate (k % camera.width, k // camera.width).

    :param camera: camera whose ``camera_rays`` produces the rays. The result
        is indexed so that ``[:, 0]`` holds origins and ``[:, 1]`` holds
        directions (presumably an array of shape [num_rays, 2, 3] -- TODO
        confirm against ProjectiveCamera).
    :param mesh: the mesh to intersect. Ray tensors are converted to the
        dtype/device of ``mesh.vertices``.
    :param ray_batch_size: number of rays per yielded batch. None (or 0)
        casts every ray in a single batch.
    :param checkpoint: forwarded to cast_rays. When None, checkpointing is
        enabled exactly when the rays are split across more than one batch.
    :return: an iterator of RayCollisions, one per consecutive batch of rays.
    """
    flat_ids = np.arange(camera.width * camera.height)
    cols = flat_ids % camera.width
    rows = flat_ids // camera.width
    pixel_coords = np.stack([cols, rows], axis=1)

    all_rays = camera.camera_rays(pixel_coords)
    total = len(all_rays)
    step = ray_batch_size if ray_batch_size else total

    if checkpoint is None:
        # Only pay the recompute cost of checkpointing when batching is used.
        checkpoint = step < total

    for start in range(0, total, step):
        chunk = all_rays[start : start + step]
        ray_batch = Rays(
            origins=torch.from_numpy(chunk[:, 0]).to(mesh.vertices),
            directions=torch.from_numpy(chunk[:, 1]).to(mesh.vertices),
        )
        yield cast_rays(ray_batch, mesh, checkpoint=checkpoint)
def cast_rays(rays: Rays, mesh: TriMesh, checkpoint: bool = False) -> RayCollisions:
    """
    Cast a batch of rays onto a mesh and return the nearest hit for each ray.

    This is a vectorized Moller-Trumbore ray/triangle test evaluated for every
    (ray, triangle) pair, so intermediate memory grows with
    num_rays * num_triangles.

    :param rays: the rays to cast. ``rays.origins`` and ``rays.directions``
        are indexed as [M x 3] below (M rays).
    :param mesh: the triangle mesh. ``mesh.vertices[mesh.faces]`` is treated
        as [N x 3 x 3] (N triangles, three 3-D corners each).
    :param checkpoint: if True, route the computation through
        RayCollisionFunction. It runs the forward pass without recording a
        graph and recomputes it during backward.
    :return: a RayCollisions holding, per ray:
        - whether any triangle was hit;
        - the ray parameter t of the nearest hit (hit = origin + t * direction);
        - the index of that triangle;
        - its barycentric coordinates;
        - its normal.
        For rays that hit nothing, ray_dists is inf and the remaining per-hit
        fields are not meaningful.
    """
    if checkpoint:
        collides, ray_dists, tri_indices, barycentric, normals = RayCollisionFunction.apply(
            rays.origins, rays.directions, mesh.faces, mesh.vertices
        )
        return RayCollisions(
            collides=collides,
            ray_dists=ray_dists,
            tri_indices=tri_indices,
            barycentric=barycentric,
            normals=normals,
        )

    # Vectorized Moller-Trumbore, following
    # https://github.com/unixpickle/vae-textures/blob/2968549ddd4a3487f9437d4db00793324453cd59/vae_textures/render.py#L98
    # Notation: M rays, N triangles. Pairwise tensors are [M x N] (rays first).
    normals = mesh.normals()  # [N x 3]
    directions = rays.directions  # [M x 3]
    # Rays (nearly) parallel to a triangle's plane cannot hit it.
    collides = (directions @ normals.T).abs() > 1e-8  # [M x N]

    tris = mesh.vertices[mesh.faces]  # [N x 3 x 3]
    v1 = tris[:, 1] - tris[:, 0]  # [N x 3]
    v2 = tris[:, 2] - tris[:, 0]  # [N x 3]

    cross1 = cross_product(directions[:, None], v2[None])  # [M x N x 3]
    det = torch.sum(cross1 * v1[None], dim=-1)  # [M x N]
    valid_det = det.abs() > 1e-8
    collides = torch.logical_and(collides, valid_det)
    # Swap degenerate determinants for 1 before inverting. Otherwise 1 / 0 = inf
    # enters the graph. Even though those pairs are masked out below, backprop
    # would then multiply a zero upstream gradient by an infinite local
    # derivative and produce NaN gradients.
    safe_det = torch.where(valid_det, det, torch.ones_like(det))
    inv_det = 1 / safe_det  # [M x N]

    o = rays.origins[:, None] - tris[None, :, 0]  # [M x N x 3]
    bary1 = inv_det * torch.sum(o * cross1, dim=-1)  # [M x N]
    collides = torch.logical_and(collides, torch.logical_and(bary1 >= 0, bary1 <= 1))

    cross2 = cross_product(o, v1[None])  # [M x N x 3]
    bary2 = inv_det * torch.sum(directions[:, None] * cross2, dim=-1)  # [M x N]
    collides = torch.logical_and(collides, torch.logical_and(bary2 >= 0, bary2 <= 1))

    # A hit lies inside the triangle only if all three barycentric coordinates
    # are non-negative. Checking bary1 and bary2 alone would also accept hits
    # in the parallelogram spanned by v1 and v2 (where bary1 + bary2 > 1).
    bary0 = 1 - (bary1 + bary2)
    collides = torch.logical_and(collides, bary0 >= 0)

    # Ray parameter t of the hit point; keep only hits in front of the origin.
    scale = inv_det * torch.sum(v2 * cross2, dim=-1)  # [M x N]
    collides = torch.logical_and(collides, scale > 0)

    # Select the nearest collision per ray. Rays with no hit get inf.
    ray_dists, tri_indices = torch.min(
        torch.where(collides, scale, torch.tensor(torch.inf).to(scale)), dim=-1
    )  # [M]

    ray_ids = torch.arange(len(tri_indices), device=tri_indices.device)
    nearest_bary = torch.stack(
        [
            bary0[ray_ids, tri_indices],
            bary1[ray_ids, tri_indices],
            bary2[ray_ids, tri_indices],
        ],
        dim=-1,
    )  # [M x 3]

    return RayCollisions(
        collides=torch.any(collides, dim=-1),
        ray_dists=ray_dists,
        tri_indices=tri_indices,
        barycentric=nearest_bary,
        normals=normals[tri_indices],
    )
class RayCollisionFunction(torch.autograd.Function):
    """
    Autograd wrapper around cast_rays that avoids keeping its intermediates.

    The forward pass:
    - saves only the inputs;
    - computes collisions under ``torch.no_grad()``.

    The backward pass:
    - recomputes cast_rays with gradients enabled;
    - differentiates the float outputs (ray_dists, barycentric, normals) with
      respect to origins, directions and vertices.

    The pairwise ray/triangle intermediates are therefore not retained between
    forward and backward, trading compute for memory.
    """

    @staticmethod
    def forward(
        ctx, origins, directions, faces, vertices
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        ctx.save_for_backward(origins, directions, faces, vertices)
        with torch.no_grad():
            res = cast_rays(
                Rays(origins=origins, directions=directions),
                TriMesh(faces=faces, vertices=vertices),
                checkpoint=False,
            )
        return (res.collides, res.ray_dists, res.tri_indices, res.barycentric, res.normals)

    @staticmethod
    def backward(
        ctx, _collides_grad, ray_dists_grad, _tri_indices_grad, barycentric_grad, normals_grad
    ):
        # Tensors stored with save_for_backward are exposed as ctx.saved_tensors.
        # ctx has no `input_tensors` attribute, so reading that raised an
        # AttributeError on every backward pass.
        origins, directions, faces, vertices = ctx.saved_tensors
        # Re-leaf the inputs so the recomputation builds a fresh local graph.
        origins = origins.detach().requires_grad_(True)
        directions = directions.detach().requires_grad_(True)
        vertices = vertices.detach().requires_grad_(True)
        with torch.enable_grad():
            outputs = cast_rays(
                Rays(origins=origins, directions=directions),
                TriMesh(faces=faces, vertices=vertices),
                checkpoint=False,
            )
            origins_grad, directions_grad, vertices_grad = torch.autograd.grad(
                (outputs.ray_dists, outputs.barycentric, outputs.normals),
                (origins, directions, vertices),
                (ray_dists_grad, barycentric_grad, normals_grad),
            )
        # faces hold indices and receive no gradient.
        return (origins_grad, directions_grad, None, vertices_grad)
|