# Spaces: Runtime error  (page-status residue captured with this file; not part of the module)
from dataclasses import dataclass | |
from functools import lru_cache | |
from typing import Tuple | |
import torch | |
from ._mc_table import MC_TABLE | |
from .torch_mesh import TorchMesh | |
def marching_cubes(
    field: torch.Tensor,
    min_point: torch.Tensor,
    size: torch.Tensor,
) -> TorchMesh:
    """
    For a signed distance field, produce a mesh using marching cubes.

    :param field: a 3D tensor of field values, where negative values correspond
                  to the outside of the shape. The dimensions correspond to the
                  x, y, and z directions, respectively. A sample counts as
                  positive only when it is strictly > 0; zero is grouped with
                  the negative samples.
    :param min_point: a tensor of shape [3] containing the point corresponding
                      to (0, 0, 0) in the field.
    :param size: a tensor of shape [3] containing the per-axis distance between
                 the (0, 0, 0) field corner and the (-1, -1, -1) field corner.
    :return: a TorchMesh whose vertices are linearly interpolated zero crossings
             along grid edges and whose faces index into those vertices. A field
             with fewer than two samples along some axis contains no cubes and
             yields a mesh with no vertices or faces.
    """
    assert len(field.shape) == 3, "input must be a 3D scalar field"
    dev = field.device
    grid_size = field.shape
    grid_size_tensor = torch.tensor(grid_size).to(size)

    lut = _lookup_table(dev)

    # Create bitmasks between 0 and 255 (inclusive) indicating the state
    # of the eight corners of each cube. Corner (dx, dy, dz) of a cube
    # (each offset 0 or 1) sets bit dx + 2 * dy + 4 * dz.
    bitmasks = (field > 0).to(torch.uint8)
    bitmasks = bitmasks[:-1, :, :] | (bitmasks[1:, :, :] << 1)
    bitmasks = bitmasks[:, :-1, :] | (bitmasks[:, 1:, :] << 2)
    bitmasks = bitmasks[:, :, :-1] | (bitmasks[:, :, 1:] << 4)

    # Integer grid coordinates of every sample, laid out as [X, Y, Z, 3].
    corner_coords = torch.empty(*grid_size, 3, device=dev, dtype=field.dtype)
    corner_coords[..., 0] = torch.arange(grid_size[0], device=dev, dtype=field.dtype)[
        :, None, None
    ]
    corner_coords[..., 1] = torch.arange(grid_size[1], device=dev, dtype=field.dtype)[:, None]
    corner_coords[..., 2] = torch.arange(grid_size[2], device=dev, dtype=field.dtype)

    # Compute all vertices across all edges in the grid, even though we will
    # throw some out later. We have (X-1)*Y*Z + X*(Y-1)*Z + X*Y*(Z-1) vertices.
    # These are all midpoints, and don't account for interpolation (which is
    # done later based on the used edge midpoints). The concatenation order
    # (x-edges, then y-edges, then z-edges) is the layout that
    # _create_flat_edge_indices indexes into.
    edge_midpoints = torch.cat(
        [
            ((corner_coords[:-1] + corner_coords[1:]) / 2).reshape(-1, 3),
            ((corner_coords[:, :-1] + corner_coords[:, 1:]) / 2).reshape(-1, 3),
            ((corner_coords[:, :, :-1] + corner_coords[:, :, 1:]) / 2).reshape(-1, 3),
        ],
        dim=0,
    )

    # Create a flat array of [X, Y, Z] indices for each cube.
    cube_indices = torch.zeros(
        grid_size[0] - 1, grid_size[1] - 1, grid_size[2] - 1, 3, device=dev, dtype=torch.long
    )
    cube_indices[..., 0] = torch.arange(grid_size[0] - 1, device=dev)[:, None, None]
    cube_indices[..., 1] = torch.arange(grid_size[1] - 1, device=dev)[:, None]
    cube_indices[..., 2] = torch.arange(grid_size[2] - 1, device=dev)
    flat_cube_indices = cube_indices.reshape(-1, 3)

    # Create a flat array mapping each cube to 12 global edge indices.
    edge_indices = _create_flat_edge_indices(flat_cube_indices, grid_size)

    # Apply the LUT to figure out the triangles. Cast to long so the uint8
    # bitmasks are treated as integer indices rather than as a boolean mask.
    flat_bitmasks = bitmasks.reshape(-1).long()
    local_tris = lut.cases[flat_bitmasks]  # [num_cubes, 5, 3] local edge indices
    local_masks = lut.masks[flat_bitmasks]  # [num_cubes, 5] which triangle slots are used

    # Compute the global edge indices for the triangles. flatten(1) is used
    # instead of reshape(num_cubes, -1) because the latter raises when
    # num_cubes == 0 (the inferred -1 dimension is ambiguous).
    global_tris = torch.gather(edge_indices, 1, local_tris.flatten(1)).reshape(local_tris.shape)

    # Select the used triangles for each cube.
    selected_tris = global_tris.reshape(-1, 3)[local_masks.reshape(-1)]

    # Now we have a bunch of indices into the full list of possible vertices,
    # but we want to reduce this list to only the used vertices.
    used_vertex_indices = torch.unique(selected_tris.view(-1))
    used_edge_midpoints = edge_midpoints[used_vertex_indices]
    old_index_to_new_index = torch.zeros(len(edge_midpoints), device=dev, dtype=torch.long)
    old_index_to_new_index[used_vertex_indices] = torch.arange(
        len(used_vertex_indices), device=dev, dtype=torch.long
    )

    # Rewrite the triangles to use the new indices.
    selected_tris = torch.gather(old_index_to_new_index, 0, selected_tris.view(-1)).reshape(
        selected_tris.shape
    )

    # Compute the actual interpolated coordinates corresponding to edge midpoints.
    # Each midpoint lies halfway between two adjacent samples, so floor/ceil
    # recover the integer indices of the edge's two endpoint samples.
    v1 = torch.floor(used_edge_midpoints).to(torch.long)
    v2 = torch.ceil(used_edge_midpoints).to(torch.long)
    s1 = field[v1[:, 0], v1[:, 1], v1[:, 2]]
    s2 = field[v2[:, 0], v2[:, 1], v2[:, 2]]
    p1 = (v1.float() / (grid_size_tensor - 1)) * size + min_point
    p2 = (v2.float() / (grid_size_tensor - 1)) * size + min_point
    # The signs of s1 and s2 should be different. We want to find
    # t such that t*s2 + (1-t)*s1 = 0.
    t = (s1 / (s1 - s2))[:, None]
    verts = t * p2 + (1 - t) * p1

    return TorchMesh(verts=verts, faces=selected_tris)
def _create_flat_edge_indices( | |
flat_cube_indices: torch.Tensor, grid_size: Tuple[int, int, int] | |
) -> torch.Tensor: | |
num_xs = (grid_size[0] - 1) * grid_size[1] * grid_size[2] | |
y_offset = num_xs | |
num_ys = grid_size[0] * (grid_size[1] - 1) * grid_size[2] | |
z_offset = num_xs + num_ys | |
return torch.stack( | |
[ | |
# Edges spanning x-axis. | |
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] | |
+ flat_cube_indices[:, 1] * grid_size[2] | |
+ flat_cube_indices[:, 2], | |
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] | |
+ (flat_cube_indices[:, 1] + 1) * grid_size[2] | |
+ flat_cube_indices[:, 2], | |
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] | |
+ flat_cube_indices[:, 1] * grid_size[2] | |
+ flat_cube_indices[:, 2] | |
+ 1, | |
flat_cube_indices[:, 0] * grid_size[1] * grid_size[2] | |
+ (flat_cube_indices[:, 1] + 1) * grid_size[2] | |
+ flat_cube_indices[:, 2] | |
+ 1, | |
# Edges spanning y-axis. | |
( | |
y_offset | |
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] | |
+ flat_cube_indices[:, 1] * grid_size[2] | |
+ flat_cube_indices[:, 2] | |
), | |
( | |
y_offset | |
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] | |
+ flat_cube_indices[:, 1] * grid_size[2] | |
+ flat_cube_indices[:, 2] | |
), | |
( | |
y_offset | |
+ flat_cube_indices[:, 0] * (grid_size[1] - 1) * grid_size[2] | |
+ flat_cube_indices[:, 1] * grid_size[2] | |
+ flat_cube_indices[:, 2] | |
+ 1 | |
), | |
( | |
y_offset | |
+ (flat_cube_indices[:, 0] + 1) * (grid_size[1] - 1) * grid_size[2] | |
+ flat_cube_indices[:, 1] * grid_size[2] | |
+ flat_cube_indices[:, 2] | |
+ 1 | |
), | |
# Edges spanning z-axis. | |
( | |
z_offset | |
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) | |
+ flat_cube_indices[:, 1] * (grid_size[2] - 1) | |
+ flat_cube_indices[:, 2] | |
), | |
( | |
z_offset | |
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) | |
+ flat_cube_indices[:, 1] * (grid_size[2] - 1) | |
+ flat_cube_indices[:, 2] | |
), | |
( | |
z_offset | |
+ flat_cube_indices[:, 0] * grid_size[1] * (grid_size[2] - 1) | |
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) | |
+ flat_cube_indices[:, 2] | |
), | |
( | |
z_offset | |
+ (flat_cube_indices[:, 0] + 1) * grid_size[1] * (grid_size[2] - 1) | |
+ (flat_cube_indices[:, 1] + 1) * (grid_size[2] - 1) | |
+ flat_cube_indices[:, 2] | |
), | |
], | |
dim=-1, | |
) | |
@dataclass
class McLookupTable:
    """
    Marching-cubes triangle table stored as tensors for vectorized lookup.

    Rows are indexed by an 8-bit corner bitmask (one bit per cube corner).
    Each row holds up to five triangles.
    """

    # Coordinates in triangles are represented as edge indices from 0-11.
    # Corner c sits at offset (c & 1, (c >> 1) & 1, (c >> 2) & 1) along (x, y, z).
    # Here is an MC cell with both corner and edge indices marked.
    #
    #       6 +------------ 3 ------------+ 7
    #        /|                          /|
    #       6 |                         7 |
    #      /  |                        /  |
    #   4 +------------ 2 ------------+ 5 |
    #     |   |                       |   |
    #     |  10                       |  11
    #     8   |                       9   |
    #     | 2 +------------ 1 --------|---+ 3
    #     |  /                        |  /
    #     | 4                         | 5
    #     |/                          |/
    #   0 +------------ 0 ------------+ 1
    #
    # cases[i, j] holds the three local edge indices of triangle slot j
    # for corner bitmask i.
    cases: torch.Tensor  # [256 x 5 x 3] long tensor
    # masks[i, j] is True when triangle slot j is populated for bitmask i.
    masks: torch.Tensor  # [256 x 5] bool tensor
@lru_cache(maxsize=9)  # one entry per device: up to 8 GPUs plus the CPU
def _lookup_table(device: torch.device) -> McLookupTable:
    """
    Build the marching-cubes lookup table from MC_TABLE on ``device``.

    Results are memoized per device (at most 9 entries, least recently used
    evicted first). The returned tensors are therefore shared between callers
    and must not be modified in place.

    Each MC_TABLE case is read as a sequence of triangles, and each triangle
    as a flat sequence of corner indices consumed in (c1, c2) pairs. Each pair
    names the cube edge a triangle vertex lies on.
    NOTE(review): inferred from the pairing below; MC_TABLE is defined
    elsewhere, so confirm against its definition.
    """
    # Unordered corner pair -> local edge index (see McLookupTable's diagram).
    edge_to_index = {
        (0, 1): 0,
        (2, 3): 1,
        (4, 5): 2,
        (6, 7): 3,
        (0, 2): 4,
        (1, 3): 5,
        (4, 6): 6,
        (5, 7): 7,
        (0, 4): 8,
        (1, 5): 9,
        (2, 6): 10,
        (3, 7): 11,
    }
    # Fill on the CPU and transfer once at the end: element-wise writes into
    # an accelerator tensor would each issue a separate small device write.
    cases = torch.zeros(256, 5, 3, dtype=torch.long)
    masks = torch.zeros(256, 5, dtype=torch.bool)
    for i, case in enumerate(MC_TABLE):
        for j, tri in enumerate(case):
            for k, (c1, c2) in enumerate(zip(tri[::2], tri[1::2])):
                cases[i, j, k] = edge_to_index[(c1, c2) if c1 < c2 else (c2, c1)]
            masks[i, j] = True
    return McLookupTable(cases=cases.to(device), masks=masks.to(device))