Spaces:
Runtime error
Runtime error
import copy | |
import inspect | |
from typing import Any, Callable, List, Sequence, Tuple, Union | |
import numpy as np | |
import torch | |
from pytorch3d.renderer import ( | |
BlendParams, | |
DirectionalLights, | |
FoVPerspectiveCameras, | |
MeshRasterizer, | |
MeshRenderer, | |
RasterizationSettings, | |
SoftPhongShader, | |
TexturesVertex, | |
) | |
from pytorch3d.renderer.utils import TensorProperties | |
from pytorch3d.structures import Meshes | |
from shap_e.models.nn.checkpoint import checkpoint | |
from .blender.constants import BASIC_AMBIENT_COLOR, BASIC_DIFFUSE_COLOR, UNIFORM_LIGHT_DIRECTION | |
from .torch_mesh import TorchMesh | |
from .view_data import ProjectiveCamera | |
# Using a lower value like 1e-4 seems to result in weird issues | |
# for our high-poly meshes. | |
DEFAULT_RENDER_SIGMA = 1e-5 | |
DEFAULT_RENDER_GAMMA = 1e-4 | |
def render_images( | |
image_size: int, | |
meshes: Meshes, | |
cameras: Any, | |
lights: Any, | |
sigma: float = DEFAULT_RENDER_SIGMA, | |
gamma: float = DEFAULT_RENDER_GAMMA, | |
max_faces_per_bin=100000, | |
faces_per_pixel=50, | |
bin_size=None, | |
use_checkpoint: bool = False, | |
) -> torch.Tensor: | |
if use_checkpoint: | |
# Decompose all of our arguments into a bunch of tensor lists | |
# so that autograd can keep track of what the op depends on. | |
verts_list = meshes.verts_list() | |
faces_list = meshes.faces_list() | |
assert isinstance(meshes.textures, TexturesVertex) | |
assert isinstance(lights, BidirectionalLights) | |
textures = meshes.textures.verts_features_padded() | |
light_vecs, light_fn = _deconstruct_tensor_props(lights) | |
camera_vecs, camera_fn = _deconstruct_tensor_props(cameras) | |
def ckpt_fn( | |
*args: torch.Tensor, | |
num_verts=len(verts_list), | |
num_light_vecs=len(light_vecs), | |
num_camera_vecs=len(camera_vecs), | |
light_fn=light_fn, | |
camera_fn=camera_fn, | |
faces_list=faces_list | |
): | |
args = list(args) | |
verts_list = args[:num_verts] | |
del args[:num_verts] | |
light_vecs = args[:num_light_vecs] | |
del args[:num_light_vecs] | |
camera_vecs = args[:num_camera_vecs] | |
del args[:num_camera_vecs] | |
textures = args.pop(0) | |
meshes = Meshes(verts=verts_list, faces=faces_list, textures=TexturesVertex(textures)) | |
lights = light_fn(light_vecs) | |
cameras = camera_fn(camera_vecs) | |
return render_images( | |
image_size=image_size, | |
meshes=meshes, | |
cameras=cameras, | |
lights=lights, | |
sigma=sigma, | |
gamma=gamma, | |
max_faces_per_bin=max_faces_per_bin, | |
faces_per_pixel=faces_per_pixel, | |
bin_size=bin_size, | |
use_checkpoint=False, | |
) | |
result = checkpoint(ckpt_fn, (*verts_list, *light_vecs, *camera_vecs, textures), (), True) | |
else: | |
raster_settings_soft = RasterizationSettings( | |
image_size=image_size, | |
blur_radius=np.log(1.0 / 1e-4 - 1.0) * sigma, | |
faces_per_pixel=faces_per_pixel, | |
max_faces_per_bin=max_faces_per_bin, | |
bin_size=bin_size, | |
perspective_correct=False, | |
) | |
renderer = MeshRenderer( | |
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings_soft), | |
shader=SoftPhongShader( | |
device=meshes.device, | |
cameras=cameras, | |
lights=lights, | |
blend_params=BlendParams(sigma=sigma, gamma=gamma, background_color=(0, 0, 0)), | |
), | |
) | |
result = renderer(meshes) | |
return result | |
def _deconstruct_tensor_props( | |
props: TensorProperties, | |
) -> Tuple[List[torch.Tensor], Callable[[List[torch.Tensor]], TensorProperties]]: | |
vecs = [] | |
names = [] | |
other_props = {} | |
for k in dir(props): | |
if k.startswith("__"): | |
continue | |
v = getattr(props, k) | |
if inspect.ismethod(v): | |
continue | |
if torch.is_tensor(v): | |
vecs.append(v) | |
names.append(k) | |
else: | |
other_props[k] = v | |
def recreate_fn(vecs_arg): | |
other = type(props)(device=props.device) | |
for k, v in other_props.items(): | |
setattr(other, k, copy.deepcopy(v)) | |
for name, vec in zip(names, vecs_arg): | |
setattr(other, name, vec) | |
return other | |
return vecs, recreate_fn | |
def convert_meshes(raw_meshes: Sequence[TorchMesh], default_brightness=0.8) -> Meshes: | |
meshes = Meshes( | |
verts=[mesh.verts for mesh in raw_meshes], faces=[mesh.faces for mesh in raw_meshes] | |
) | |
rgbs = [] | |
for mesh in raw_meshes: | |
if mesh.vertex_channels and all(k in mesh.vertex_channels for k in "RGB"): | |
rgbs.append(torch.stack([mesh.vertex_channels[k] for k in "RGB"], axis=-1)) | |
else: | |
rgbs.append( | |
torch.ones( | |
len(mesh.verts) * default_brightness, | |
3, | |
device=mesh.verts.device, | |
dtype=mesh.verts.dtype, | |
) | |
) | |
meshes.textures = TexturesVertex(verts_features=rgbs) | |
return meshes | |
def convert_cameras( | |
cameras: Sequence[ProjectiveCamera], device: torch.device | |
) -> FoVPerspectiveCameras: | |
Rs = [] | |
Ts = [] | |
for camera in cameras: | |
assert ( | |
camera.width == camera.height and camera.x_fov == camera.y_fov | |
), "viewports must be square" | |
assert camera.x_fov == cameras[0].x_fov, "all cameras must have same field-of-view" | |
R = np.stack([-camera.x, -camera.y, camera.z], axis=0).T | |
T = -R.T @ camera.origin | |
Rs.append(R) | |
Ts.append(T) | |
return FoVPerspectiveCameras( | |
R=np.stack(Rs, axis=0), | |
T=np.stack(Ts, axis=0), | |
fov=cameras[0].x_fov, | |
degrees=False, | |
device=device, | |
) | |
def convert_cameras_torch( | |
origins: torch.Tensor, xs: torch.Tensor, ys: torch.Tensor, zs: torch.Tensor, fov: float | |
) -> FoVPerspectiveCameras: | |
Rs = [] | |
Ts = [] | |
for origin, x, y, z in zip(origins, xs, ys, zs): | |
R = torch.stack([-x, -y, z], axis=0).T | |
T = -R.T @ origin | |
Rs.append(R) | |
Ts.append(T) | |
return FoVPerspectiveCameras( | |
R=torch.stack(Rs, dim=0), | |
T=torch.stack(Ts, dim=0), | |
fov=fov, | |
degrees=False, | |
device=origins.device, | |
) | |
def blender_uniform_lights( | |
batch_size: int, | |
device: torch.device, | |
ambient_color: Union[float, Tuple[float]] = BASIC_AMBIENT_COLOR, | |
diffuse_color: Union[float, Tuple[float]] = BASIC_DIFFUSE_COLOR, | |
specular_color: Union[float, Tuple[float]] = 0.0, | |
) -> "BidirectionalLights": | |
""" | |
Create a light that attempts to match the light used by the Blender | |
renderer when run with `--light_mode basic`. | |
""" | |
if isinstance(ambient_color, float): | |
ambient_color = (ambient_color,) * 3 | |
if isinstance(diffuse_color, float): | |
diffuse_color = (diffuse_color,) * 3 | |
if isinstance(specular_color, float): | |
specular_color = (specular_color,) * 3 | |
return BidirectionalLights( | |
ambient_color=(ambient_color,) * batch_size, | |
diffuse_color=(diffuse_color,) * batch_size, | |
specular_color=(specular_color,) * batch_size, | |
direction=(UNIFORM_LIGHT_DIRECTION,) * batch_size, | |
device=device, | |
) | |
class BidirectionalLights(DirectionalLights): | |
""" | |
Adapted from here, but effectively shines the light in both positive and negative directions: | |
https://github.com/facebookresearch/pytorch3d/blob/efea540bbcab56fccde6f4bc729d640a403dac56/pytorch3d/renderer/lighting.py#L159 | |
""" | |
def diffuse(self, normals, points=None) -> torch.Tensor: | |
return torch.maximum( | |
super().diffuse(normals, points=points), super().diffuse(-normals, points=points) | |
) | |
def specular(self, normals, points, camera_position, shininess) -> torch.Tensor: | |
return torch.maximum( | |
super().specular(normals, points, camera_position, shininess), | |
super().specular(-normals, points, camera_position, shininess), | |
) | |