import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from shap_e.models.nn.camera import DifferentiableCamera, DifferentiableProjectiveCamera
from shap_e.models.nn.meta import subdict
from shap_e.models.nn.utils import to_torch
from shap_e.models.query import Query
from shap_e.models.renderer import Renderer, get_camera_from_batch
from shap_e.models.volume import BoundingBoxVolume, Volume
from shap_e.rendering.blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR
from shap_e.rendering.mc import marching_cubes
from shap_e.rendering.torch_mesh import TorchMesh
from shap_e.rendering.view_data import ProjectiveCamera
from shap_e.util.collections import AttrDict

from .base import Model


class STFRendererBase(ABC):
    """
    Interface for renderers that expose a signed distance function (SDF) and
    a texture function (TF) over 3D query points.
    """

    @abstractmethod
    def get_signed_distance(
        self,
        query: Query,
        params: Dict[str, torch.Tensor],
        options: AttrDict[str, Any],
    ) -> torch.Tensor:
        pass

    @abstractmethod
    def get_texture(
        self,
        query: Query,
        params: Dict[str, torch.Tensor],
        options: AttrDict[str, Any],
    ) -> torch.Tensor:
        pass


class STFRenderer(Renderer, STFRendererBase):
    def __init__(
        self,
        sdf: Model,
        tf: Model,
        volume: Volume,
        grid_size: int,
        texture_channels: Sequence[str] = ("R", "G", "B"),
        channel_scale: Sequence[float] = (255.0, 255.0, 255.0),
        ambient_color: Union[float, Tuple[float, ...]] = BASIC_AMBIENT_COLOR,
        diffuse_color: Union[float, Tuple[float, ...]] = BASIC_DIFFUSE_COLOR,
        specular_color: Union[float, Tuple[float, ...]] = 0.0,
        output_srgb: bool = True,
        device: torch.device = torch.device("cuda"),
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert isinstance(volume, BoundingBoxVolume), "cannot sample points in unknown volume"

        self.sdf = sdf
        self.tf = tf
        self.volume = volume
        self.grid_size = grid_size
        self.texture_channels = texture_channels
        self.channel_scale = to_torch(channel_scale).to(device)
        self.ambient_color = ambient_color
        self.diffuse_color = diffuse_color
        self.specular_color = specular_color
        self.output_srgb = output_srgb
        self.device = device

        self.to(device)

    def render_views(
        self,
        batch: Dict,
        params: Optional[Dict] = None,
        options: Optional[Dict] = None,
    ) -> AttrDict:
        params = self.update(params)
        options = AttrDict() if not options else AttrDict(options)

        sdf_fn = partial(self.sdf.forward_batched, params=subdict(params, "sdf"))
        tf_fn = partial(self.tf.forward_batched, params=subdict(params, "tf"))
        nerstf_fn = None

        return render_views_from_stf(
            batch,
            options,
            sdf_fn=sdf_fn,
            tf_fn=tf_fn,
            nerstf_fn=nerstf_fn,
            volume=self.volume,
            grid_size=self.grid_size,
            channel_scale=self.channel_scale,
            texture_channels=self.texture_channels,
            ambient_color=self.ambient_color,
            diffuse_color=self.diffuse_color,
            specular_color=self.specular_color,
            output_srgb=self.output_srgb,
            device=self.device,
        )

    def get_signed_distance(
        self,
        query: Query,
        params: Dict[str, torch.Tensor],
        options: AttrDict[str, Any],
    ) -> torch.Tensor:
        return self.sdf(
            query,
            params=subdict(params, "sdf"),
            options=options,
        ).signed_distance

    def get_texture(
        self,
        query: Query,
        params: Dict[str, torch.Tensor],
        options: AttrDict[str, Any],
    ) -> torch.Tensor:
        return self.tf(
            query,
            params=subdict(params, "tf"),
            options=options,
        ).channels
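

# Example usage (an illustrative sketch, not from the library; `sdf_model`,
# `tf_model`, `batch`, and `params` are hypothetical stand-ins):
#
#   renderer = STFRenderer(
#       sdf=sdf_model,  # any Model whose output has .signed_distance
#       tf=tf_model,    # any Model whose output has .channels
#       volume=BoundingBoxVolume(bbox_min=[-1.0] * 3, bbox_max=[1.0] * 3),
#       grid_size=128,
#   )
#   out = renderer.render_views(batch, params=params)
#   images = out.channels  # [batch_size, n_views, height, width, 3]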


def render_views_from_stf(
    batch: Dict,
    options: AttrDict[str, Any],
    *,
    sdf_fn: Optional[Callable],
    tf_fn: Optional[Callable],
    nerstf_fn: Optional[Callable],
    volume: BoundingBoxVolume,
    grid_size: int,
    channel_scale: torch.Tensor,
    texture_channels: Sequence[str] = ("R", "G", "B"),
    ambient_color: Union[float, Tuple[float, ...]] = 0.0,
    diffuse_color: Union[float, Tuple[float, ...]] = 1.0,
    specular_color: Union[float, Tuple[float, ...]] = 0.2,
    output_srgb: bool = False,
    device: torch.device = torch.device("cuda"),
) -> AttrDict:
    """
    :param batch: contains either ["poses", "camera"] or ["cameras"]. Can
        optionally contain any of ["height", "width", "query_batch_size"].
    :param options: controls checkpointing, caching, and rendering.
    :param sdf_fn: returns [batch_size, query_batch_size, n_output], where
        n_output >= 1.
    :param tf_fn: returns [batch_size, query_batch_size, n_channels].
    :param volume: the axis-aligned bounding box (AABB) volume to sample in.
    :param grid_size: the SDF sampling resolution along each axis.
    :param texture_channels: the texture channels to predict.
    :param channel_scale: how much each output channel is scaled.
    :return: an AttrDict containing at least
        channels: [batch_size, len(cameras), height, width, 3]
        transmittance: [batch_size, len(cameras), height, width, 1]
        aux_losses: AttrDict[str, torch.Tensor]
    """
    camera, batch_size, inner_shape = get_camera_from_batch(batch)
    inner_batch_size = int(np.prod(inner_shape))

    assert camera.width == camera.height, "only square views are supported"
    assert camera.x_fov == camera.y_fov, "only square views are supported"
    assert isinstance(camera, DifferentiableProjectiveCamera)

    device = camera.origin.device
    device_type = device.type

    query_batch_size = batch.get("query_batch_size", batch.get("ray_batch_size", 4096))

    # The SDF evaluation and mesh extraction can be cached and reused across
    # calls, e.g. when rendering additional views of the same fields.
    TO_CACHE = ["fields", "raw_meshes", "raw_signed_distance", "raw_density", "mesh_mask"]
    if options.cache is not None and all(key in options.cache for key in TO_CACHE):
        fields = options.cache.fields
        raw_meshes = options.cache.raw_meshes
        raw_signed_distance = options.cache.raw_signed_distance
        raw_density = options.cache.raw_density
        mesh_mask = options.cache.mesh_mask
    else:
        query_points = volume_query_points(volume, grid_size)
        fn = nerstf_fn if sdf_fn is None else sdf_fn
        sdf_out = fn(
            query=Query(position=query_points[None].repeat(batch_size, 1, 1)),
            query_batch_size=query_batch_size,
            options=options,
        )
        raw_signed_distance = sdf_out.signed_distance
        raw_density = None
        if "density" in sdf_out:
            raw_density = sdf_out.density

        with torch.autocast(device_type, enabled=False):
            fields = sdf_out.signed_distance.float()
            assert (
                len(fields.shape) == 3 and fields.shape[-1] == 1
            ), f"expected [meta_batch x inner_batch] SDF results, but got {fields.shape}"
            fields = fields.reshape(batch_size, *([grid_size] * 3))

            # Force a negative border around the SDFs to close off all the models.
            full_grid = torch.zeros(
                batch_size,
                grid_size + 2,
                grid_size + 2,
                grid_size + 2,
                device=fields.device,
                dtype=fields.dtype,
            )
            full_grid.fill_(-1.0)
            full_grid[:, 1:-1, 1:-1, 1:-1] = fields
            fields = full_grid

            # Run marching cubes per sample to extract a triangle mesh.
            raw_meshes = []
            mesh_mask = []
            for field in fields:
                raw_mesh = marching_cubes(field, volume.bbox_min, volume.bbox_max - volume.bbox_min)
                if len(raw_mesh.faces) == 0:
                    # DDP deadlocks when there are unused parameters on some ranks
                    # and not others, so we make sure the field is a dependency in
                    # the graph regardless of empty meshes.
                    vertex_dependency = field.mean()
                    raw_mesh = TorchMesh(
                        verts=torch.zeros(3, 3, device=device) + vertex_dependency,
                        faces=torch.tensor([[0, 1, 2]], dtype=torch.long, device=device),
                    )
                    # Make sure we only feed back zero gradients to the field
                    # by masking out the final renderings of this mesh.
                    mesh_mask.append(False)
                else:
                    mesh_mask.append(True)
                raw_meshes.append(raw_mesh)
            mesh_mask = torch.tensor(mesh_mask, device=device)

    # Query the texture function at the mesh vertices. Meshes are padded to a
    # common vertex count (by wrapping indices) so they can be batched.
    max_vertices = max(len(m.verts) for m in raw_meshes)
    fn = nerstf_fn if tf_fn is None else tf_fn
    tf_out = fn(
        query=Query(
            position=torch.stack(
                [m.verts[torch.arange(0, max_vertices) % len(m.verts)] for m in raw_meshes],
                dim=0,
            )
        ),
        query_batch_size=query_batch_size,
        options=options,
    )

    if options.cache is not None:
        options.cache.fields = fields
        options.cache.raw_meshes = raw_meshes
        options.cache.raw_signed_distance = raw_signed_distance
        options.cache.raw_density = raw_density
        options.cache.mesh_mask = mesh_mask
    if output_srgb:
        tf_out.channels = _convert_srgb_to_linear(tf_out.channels)

    # Make sure the raw meshes have colors.
    with torch.autocast(device_type, enabled=False):
        textures = tf_out.channels.float()
        assert len(textures.shape) == 3 and textures.shape[-1] == len(
            texture_channels
        ), f"expected [meta_batch x inner_batch x texture_channels] field results, but got {textures.shape}"
        for m, texture in zip(raw_meshes, textures):
            texture = texture[: len(m.verts)]
            m.vertex_channels = dict(zip(texture_channels, texture.unbind(-1)))

    args = dict(
        options=options,
        texture_channels=texture_channels,
        ambient_color=ambient_color,
        diffuse_color=diffuse_color,
        specular_color=specular_color,
        camera=camera,
        batch_size=batch_size,
        inner_batch_size=inner_batch_size,
        inner_shape=inner_shape,
        raw_meshes=raw_meshes,
        tf_out=tf_out,
    )

    try:
        out = _render_with_pytorch3d(**args)
    except ModuleNotFoundError as exc:
        warnings.warn(f"exception rendering with PyTorch3D: {exc}")
        warnings.warn(
            "falling back on native PyTorch renderer, which does not support full gradients"
        )
        out = _render_with_raycast(**args)

    # Apply mask to prevent gradients for empty meshes.
    reshaped_mask = mesh_mask.view([-1] + [1] * (len(out.channels.shape) - 1))
    out.channels = torch.where(reshaped_mask, out.channels, torch.zeros_like(out.channels))
    out.transmittance = torch.where(
        reshaped_mask, out.transmittance, torch.ones_like(out.transmittance)
    )

    if output_srgb:
        out.channels = _convert_linear_to_srgb(out.channels)

    # Composite over the background and rescale to the output range.
    out.channels = out.channels * (1 - out.transmittance) * channel_scale.view(-1)

    # This might be useful information to have downstream.
    out.raw_meshes = raw_meshes
    out.fields = fields
    out.mesh_mask = mesh_mask
    out.raw_signed_distance = raw_signed_distance
    out.aux_losses = AttrDict(cross_entropy=cross_entropy_sdf_loss(fields))
    if raw_density is not None:
        out.raw_density = raw_density

    return out
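

# Note on caching (an illustrative sketch): passing a persistent cache in
# `options` lets repeated calls (e.g. rendering more views of the same fields)
# skip the SDF queries and marching cubes above:
#
#   options = AttrDict(cache=AttrDict())
#   first = render_views_from_stf(batch_a, options, ...)   # populates the cache
#   second = render_views_from_stf(batch_b, options, ...)  # reuses cached meshes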


def _render_with_pytorch3d(
    options: AttrDict,
    texture_channels: Sequence[str],
    ambient_color: Union[float, Tuple[float, ...]],
    diffuse_color: Union[float, Tuple[float, ...]],
    specular_color: Union[float, Tuple[float, ...]],
    camera: DifferentiableCamera,
    batch_size: int,
    inner_shape: Sequence[int],
    inner_batch_size: int,
    raw_meshes: List[TorchMesh],
    tf_out: AttrDict,
):
    _ = tf_out

    # Lazy import because pytorch3d is installed lazily.
    from shap_e.rendering.pytorch3d_util import (
        blender_uniform_lights,
        convert_cameras_torch,
        convert_meshes,
        render_images,
    )

    n_channels = len(texture_channels)
    device = camera.origin.device
    device_type = device.type

    with torch.autocast(device_type, enabled=False):
        meshes = convert_meshes(raw_meshes)

        lights = blender_uniform_lights(
            batch_size,
            device,
            ambient_color=ambient_color,
            diffuse_color=diffuse_color,
            specular_color=specular_color,
        )

        # Separate camera intrinsics for each view, so that we can
        # create a new camera for each batch of views.
        cam_shape = [batch_size, inner_batch_size, -1]
        position = camera.origin.reshape(cam_shape)
        x = camera.x.reshape(cam_shape)
        y = camera.y.reshape(cam_shape)
        z = camera.z.reshape(cam_shape)

        results = []
        for i in range(inner_batch_size):
            sub_cams = convert_cameras_torch(
                position[:, i], x[:, i], y[:, i], z[:, i], fov=camera.x_fov
            )
            imgs = render_images(
                camera.width,
                meshes,
                sub_cams,
                lights,
                use_checkpoint=options.checkpoint_render,
                **options.get("render_options", {}),
            )
            results.append(imgs)
        views = torch.stack(results, dim=1)
        views = views.view(batch_size, *inner_shape, camera.height, camera.width, n_channels + 1)

    # The last image channel is alpha; transmittance is its complement.
    out = AttrDict(
        channels=views[..., :-1],  # [batch_size, *inner_shape, height, width, n_channels]
        transmittance=1 - views[..., -1:],  # [batch_size, *inner_shape, height, width, 1]
        meshes=meshes,
    )
    return out


def _render_with_raycast(
    options: AttrDict,
    texture_channels: Sequence[str],
    ambient_color: Union[float, Tuple[float, ...]],
    diffuse_color: Union[float, Tuple[float, ...]],
    specular_color: Union[float, Tuple[float, ...]],
    camera: DifferentiableCamera,
    batch_size: int,
    inner_shape: Sequence[int],
    inner_batch_size: int,
    raw_meshes: List[TorchMesh],
    tf_out: AttrDict,
):
    # The raycast fallback only supports diffuse and ambient shading.
    assert np.mean(np.array(specular_color)) == 0

    from shap_e.rendering.raycast.render import render_diffuse_mesh
    from shap_e.rendering.raycast.types import TriMesh as TorchTriMesh

    device = camera.origin.device
    device_type = device.type

    cam_shape = [batch_size, inner_batch_size, -1]
    origin = camera.origin.reshape(cam_shape)
    x = camera.x.reshape(cam_shape)
    y = camera.y.reshape(cam_shape)
    z = camera.z.reshape(cam_shape)

    with torch.autocast(device_type, enabled=False):
        all_meshes = []
        for i, mesh in enumerate(raw_meshes):
            all_meshes.append(
                TorchTriMesh(
                    faces=mesh.faces.long(),
                    vertices=mesh.verts.float(),
                    vertex_colors=tf_out.channels[i, : len(mesh.verts)].float(),
                )
            )

        all_images = []
        for i, mesh in enumerate(all_meshes):
            for j in range(inner_batch_size):
                all_images.append(
                    render_diffuse_mesh(
                        camera=ProjectiveCamera(
                            origin=origin[i, j].detach().cpu().numpy(),
                            x=x[i, j].detach().cpu().numpy(),
                            y=y[i, j].detach().cpu().numpy(),
                            z=z[i, j].detach().cpu().numpy(),
                            width=camera.width,
                            height=camera.height,
                            x_fov=camera.x_fov,
                            y_fov=camera.y_fov,
                        ),
                        mesh=mesh,
                        diffuse=float(np.array(diffuse_color).mean()),
                        ambient=float(np.array(ambient_color).mean()),
                        ray_batch_size=16,  # low memory usage
                        checkpoint=options.checkpoint_render,
                    )
                )

    n_channels = len(texture_channels)
    views = torch.stack(all_images).view(
        batch_size, *inner_shape, camera.height, camera.width, n_channels + 1
    )
    return AttrDict(
        channels=views[..., :-1],  # [batch_size, *inner_shape, height, width, n_channels]
        transmittance=1 - views[..., -1:],  # [batch_size, *inner_shape, height, width, 1]
        meshes=all_meshes,
    )


def _convert_srgb_to_linear(u: torch.Tensor) -> torch.Tensor:
    return torch.where(u <= 0.04045, u / 12.92, ((u + 0.055) / 1.055) ** 2.4)


def _convert_linear_to_srgb(u: torch.Tensor) -> torch.Tensor:
    return torch.where(u <= 0.0031308, 12.92 * u, 1.055 * (u ** (1 / 2.4)) - 0.055)
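

# These are the standard sRGB transfer functions (IEC 61966-2-1). The two
# branch thresholds correspond (0.04045 / 12.92 ≈ 0.0031308), so the
# functions are mutual inverses on [0, 1].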


def cross_entropy_sdf_loss(fields: torch.Tensor):
    # Penalize disagreement between each grid cell's occupancy sign and the
    # prediction at its neighboring cell, encouraging smooth SDF fields.
    logits = F.logsigmoid(fields)
    signs = (fields > 0).float()

    losses = []
    for dim in range(1, 4):
        n = logits.shape[dim]
        # Compare each slice against its adjacent slice in both directions.
        for t_start, t_end, p_start, p_end in [(0, -1, 1, n), (1, n, 0, -1)]:
            targets = slice_fields(signs, dim, t_start, t_end)
            preds = slice_fields(logits, dim, p_start, p_end)
            losses.append(
                F.binary_cross_entropy_with_logits(preds, targets, reduction="none")
                .flatten(1)
                .mean()
            )
    return torch.stack(losses, dim=-1).sum()


def slice_fields(fields: torch.Tensor, dim: int, start: int, end: int):
    if dim == 1:
        return fields[:, start:end]
    elif dim == 2:
        return fields[:, :, start:end]
    elif dim == 3:
        return fields[:, :, :, start:end]
    else:
        raise ValueError(f"cannot slice dimension {dim}")
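

# For example, slice_fields(t, 2, 0, -1) is equivalent to t[:, :, 0:-1],
# i.e. it drops the last entry along dimension 2.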


def volume_query_points(
    volume: Volume,
    grid_size: int,
):
    assert isinstance(volume, BoundingBoxVolume)

    # Enumerate grid indices in x-major order (z varies fastest) and map them
    # into the box, so the grid corners coincide with the bounding box corners.
    indices = torch.arange(grid_size**3, device=volume.bbox_min.device)
    zs = indices % grid_size
    ys = torch.div(indices, grid_size, rounding_mode="trunc") % grid_size
    xs = torch.div(indices, grid_size**2, rounding_mode="trunc") % grid_size
    combined = torch.stack([xs, ys, zs], dim=1)
    return (combined.float() / (grid_size - 1)) * (
        volume.bbox_max - volume.bbox_min
    ) + volume.bbox_min
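

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the library).
    # It assumes BoundingBoxVolume accepts bbox_min/bbox_max keyword
    # arguments, as its attribute usage above (volume.bbox_min/bbox_max)
    # suggests.
    volume = BoundingBoxVolume(bbox_min=[-1.0, -1.0, -1.0], bbox_max=[1.0, 1.0, 1.0])
    points = volume_query_points(volume, grid_size=4)
    assert points.shape == (4**3, 3)
    # The first and last grid points are the bounding box corners.
    assert torch.allclose(points[0], volume.bbox_min)
    assert torch.allclose(points[-1], volume.bbox_max)

    # The sRGB helpers should round-trip on [0, 1].
    u = torch.linspace(0.0, 1.0, steps=256)
    assert torch.allclose(_convert_linear_to_srgb(_convert_srgb_to_linear(u)), u, atol=1e-5)
    print("ok")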