Spaces:
Running
on
Zero
Running
on
Zero
# This file uses the nvdiffrast library, which is licensed under the NVIDIA Source Code License (1-Way Commercial).
# nvdiffrast is available for non-commercial use (research or evaluation purposes only).
# For commercial use, please contact NVIDIA for licensing: https://www.nvidia.com/en-us/research/inquiries/
#
# nvdiffrast copyright: Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
# Full license: https://github.com/NVlabs/nvdiffrast/blob/main/LICENSE.txt
from typing import Tuple, Union | |
import nvdiffrast.torch as dr | |
import torch | |
from jaxtyping import Float, Integer | |
from torch import Tensor | |
class NVDiffRasterizerContext:
    """Wrapper around an nvdiffrast rasterization context.

    Holds either a GL or a CUDA nvdiffrast context and forwards to the
    nvdiffrast ops, casting attributes with ``.float()`` and triangle
    indices with ``.int()`` before each call.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, context_type: str, device) -> None:
        """Create the underlying nvdiffrast context.

        Args:
            context_type: ``"gl"`` or ``"cuda"``.
            device: Device handed to the nvdiffrast context constructor.

        Raises:
            ValueError: If ``context_type`` is not recognised.
        """
        # Device the context was created for; read by module-level helpers.
        self.device = device
        # The nvdiffrast context object passed to ``dr.rasterize``.
        self.ctx = self.initialize_context(context_type, device)

    def initialize_context(
        self, context_type: str, device
    ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
        """Build the nvdiffrast context matching ``context_type``.

        Args:
            context_type: ``"gl"`` for ``dr.RasterizeGLContext`` or
                ``"cuda"`` for ``dr.RasterizeCudaContext``.
            device: Device forwarded to the context constructor.

        Returns:
            The newly created nvdiffrast context.

        Raises:
            ValueError: If ``context_type`` is neither ``"gl"`` nor ``"cuda"``.
        """
        if context_type == "gl":
            return dr.RasterizeGLContext(device=device)
        if context_type == "cuda":
            return dr.RasterizeCudaContext(device=device)
        raise ValueError(f"Unknown rasterizer context type: {context_type}")

    def vertex_transform(
        self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
    ) -> Float[Tensor, "B Nv 4"]:
        """Transform vertices by a batch of MVP matrices.

        A ``w = 1`` column is appended to ``verts`` and the result is
        multiplied by each transposed matrix in ``mvp_mtx``. CUDA autocast is
        disabled for the computation.

        Args:
            verts: Vertex positions.
            mvp_mtx: Batch of 4x4 matrices.

        Returns:
            Homogeneous transformed positions, one set per matrix.
        """
        with torch.amp.autocast("cuda", enabled=False):
            # Allocate the w column directly with verts' dtype and device
            # instead of building it on the CPU and transferring it.
            ones = verts.new_ones((verts.shape[0], 1))
            verts_homo = torch.cat([verts, ones], dim=-1)
            verts_clip = torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1))
        return verts_clip

    def rasterize(
        self,
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ):
        """Rasterize a batch of views of one mesh topology.

        Args:
            pos: Transformed vertex positions, one set per view.
            tri: Triangle vertex indices shared by all views.
            resolution: ``(height, width)`` pair, or a single int for a
                square image.

        Returns:
            Whatever ``dr.rasterize`` returns; callers in this file unpack it
            as a ``(rast, rast_db)`` pair.
        """
        # nvdiffrast needs a (height, width) sequence; expand a bare int so
        # the ``int`` form advertised by the annotation actually works.
        if isinstance(resolution, int):
            resolution = (resolution, resolution)
        # rasterize in instance mode (single topology)
        return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True)

    def rasterize_one(
        self,
        pos: Float[Tensor, "Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ):
        """Rasterize a single mesh under a single viewpoint.

        Adds a leading batch axis to ``pos``, calls :meth:`rasterize`, and
        strips the batch axis from both outputs.

        Args:
            pos: Transformed vertex positions for one view.
            tri: Triangle vertex indices.
            resolution: ``(height, width)`` pair, or a single int for a
                square image.

        Returns:
            ``(rast[0], rast_db[0])`` from the batched rasterization.
        """
        rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
        return rast[0], rast_db[0]

    def antialias(
        self,
        color: Float[Tensor, "B H W C"],
        rast: Float[Tensor, "B H W 4"],
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
    ) -> Float[Tensor, "B H W C"]:
        """Apply ``dr.antialias`` to ``color``.

        Args:
            color: Image to antialias.
            rast: Rasterizer output for the same views.
            pos: Transformed vertex positions used for rasterization.
            tri: Triangle vertex indices used for rasterization.

        Returns:
            The result of ``dr.antialias``.
        """
        return dr.antialias(color.float(), rast, pos.float(), tri.int())

    def interpolate(
        self,
        attr: Float[Tensor, "B Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate per-vertex attributes over the rasterized image.

        Args:
            attr: Per-vertex attributes.
            rast: Rasterizer output.
            tri: Triangle indices into ``attr``.
            rast_db: Forwarded unchanged to ``dr.interpolate``.
            diff_attrs: Forwarded unchanged to ``dr.interpolate``.

        Returns:
            The pair returned by ``dr.interpolate``; callers in this file use
            the first element as the interpolated attribute image.
        """
        return dr.interpolate(
            attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
        )

    def interpolate_one(
        self,
        attr: Float[Tensor, "Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate attributes that have no batch axis.

        Adds a leading batch axis to ``attr`` and delegates to
        :meth:`interpolate`.

        Args:
            attr: Per-vertex attributes without a batch axis.
            rast: Rasterizer output.
            tri: Triangle indices into ``attr``.
            rast_db: Forwarded unchanged to :meth:`interpolate`.
            diff_attrs: Forwarded unchanged to :meth:`interpolate`.

        Returns:
            The pair returned by :meth:`interpolate`.
        """
        return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs)
def texture_map_to_rgb(tex_map, uv_coordinates):
    """Sample ``tex_map`` at ``uv_coordinates`` with ``dr.texture``.

    Args:
        tex_map: Texture tensor; cast with ``.float()`` before sampling.
        uv_coordinates: Texture coordinates passed unchanged to ``dr.texture``.

    Returns:
        The result of ``dr.texture``.
    """
    texture = tex_map.float()
    sampled = dr.texture(texture, uv_coordinates)
    return sampled
def render_rgb_from_texture_mesh_with_mask(
    ctx,
    mesh,
    tex_map: Float[Tensor, "1 H W C"],
    mvp_matrix: Float[Tensor, "batch 4 4"],
    image_height: int,
    image_width: int,
    background_color: Tensor = torch.tensor([0.0, 0.0, 0.0]),
):
    """Render a textured mesh to RGB images, one per MVP matrix.

    The mesh is rasterized for every matrix in ``mvp_matrix``, texture
    coordinates are interpolated and used to sample ``tex_map``, and the
    result is blended over ``background_color`` with an antialiased
    coverage mask before a final antialiasing pass.

    Args:
        ctx: Rasterizer wrapper providing ``vertex_transform``, ``rasterize``,
            ``antialias`` and ``interpolate_one``.
        mesh: Object exposing ``v_pos``, ``t_pos_idx``, ``_v_tex`` and
            ``_t_tex_idx`` tensors.
        tex_map: Texture image; a 3-D tensor gets a leading batch axis added.
        mvp_matrix: Batch of MVP matrices; its device is used for all inputs.
        image_height: Output height in pixels.
        image_width: Output width in pixels.
        background_color: Three-element colour for uncovered pixels. The
            default tensor is shared between calls but is only read here.

    Returns:
        ``(final_rgb_aa, selector)``: the antialiased image with three
        channels, and the boolean mask of pixels whose last rasterizer
        channel is positive.
    """
    # Move every input to the MVP device once, as render_geo_from_mesh does,
    # so a mesh or texture that lives elsewhere does not fail mid-render.
    device = mvp_matrix.device
    batch_size = mvp_matrix.shape[0]

    tex_map = tex_map.to(device).contiguous()
    if tex_map.dim() == 3:
        tex_map = tex_map.unsqueeze(0)  # Add batch dimension if missing

    vertices = mesh.v_pos.to(device)
    faces = mesh.t_pos_idx.to(device)
    uv_vertices = mesh._v_tex.to(device)
    uv_faces = mesh._t_tex_idx.to(device)

    vertex_positions_clip = ctx.vertex_transform(vertices, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(
        vertex_positions_clip, faces, (image_height, image_width)
    )

    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(
        mask.float(), rasterized_output, vertex_positions_clip, faces
    )

    interpolated_texture_coords, _ = ctx.interpolate_one(
        uv_vertices, rasterized_output, uv_faces
    )
    # texture_map_to_rgb casts the texture to float itself.
    rgb_foreground = texture_map_to_rgb(tex_map, interpolated_texture_coords)

    # Allocate the canvases with the foreground's dtype and device directly
    # rather than creating them on the CPU and transferring them.
    canvas_shape = (batch_size, image_height, image_width, 3)
    rgb_foreground_batched = rgb_foreground.new_zeros(canvas_shape)
    rgb_background_batched = rgb_foreground.new_zeros(canvas_shape)
    rgb_background_batched += background_color.view(1, 1, 1, 3).to(rgb_foreground)

    # Copy sampled colours only where the mask is set; the rest stays zero.
    selector = mask[..., 0]
    rgb_foreground_batched[selector] = rgb_foreground[selector]

    # Use the anti-aliased mask for blending
    final_rgb = torch.lerp(
        rgb_background_batched, rgb_foreground_batched, mask_antialiased
    )
    final_rgb_aa = ctx.antialias(
        final_rgb, rasterized_output, vertex_positions_clip, faces
    )
    return final_rgb_aa, selector
def render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width):
    """Render interpolated vertex positions and normals as images.

    The mesh is rasterized for every matrix in ``mvp_matrix``; positions and
    normals are interpolated per pixel, blended over a zero background with
    an antialiased coverage mask, and antialiased once more.

    Args:
        ctx: Rasterizer wrapper providing ``vertex_transform``, ``rasterize``,
            ``antialias`` and ``interpolate_one``.
        mesh: Object exposing ``v_pos``, ``t_pos_idx`` and ``v_normal``
            tensors.
        mvp_matrix: Batch of MVP matrices; its device is used for all inputs.
        image_height: Output height in pixels.
        image_width: Output width in pixels.

    Returns:
        ``(final_pos_rgb_aa, final_norm_rgb_aa, mask_antialiased)``: the
        antialiased position image, the antialiased normal image, and the
        antialiased coverage mask.
    """
    device = mvp_matrix.device

    # Transfer each mesh tensor once; the same tensors feed every call below.
    vertices = mesh.v_pos.to(device)
    faces = mesh.t_pos_idx.to(device)
    normals = mesh.v_normal.to(device).contiguous()

    vertex_positions_clip = ctx.vertex_transform(vertices, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(
        vertex_positions_clip, faces, (image_height, image_width)
    )
    interpolated_positions, _ = ctx.interpolate_one(
        vertices, rasterized_output, faces
    )
    interpolated_normals, _ = ctx.interpolate_one(
        normals, rasterized_output, faces
    )

    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(
        mask.float(), rasterized_output, vertex_positions_clip, faces
    )

    # Allocate the canvases with the interpolated positions' dtype and device
    # directly rather than creating them on the CPU and transferring them.
    batch_size = mvp_matrix.shape[0]
    canvas_shape = (batch_size, image_height, image_width, 3)
    rgb_foreground_pos_batched = interpolated_positions.new_zeros(canvas_shape)
    rgb_foreground_norm_batched = interpolated_positions.new_zeros(canvas_shape)
    rgb_background_batched = interpolated_positions.new_zeros(canvas_shape)

    # Copy values only where the mask is set; the rest stays zero.
    selector = mask[..., 0]
    rgb_foreground_pos_batched[selector] = interpolated_positions[selector]
    rgb_foreground_norm_batched[selector] = interpolated_normals[selector]

    final_pos_rgb = torch.lerp(
        rgb_background_batched, rgb_foreground_pos_batched, mask_antialiased
    )
    final_norm_rgb = torch.lerp(
        rgb_background_batched, rgb_foreground_norm_batched, mask_antialiased
    )

    final_pos_rgb_aa = ctx.antialias(
        final_pos_rgb, rasterized_output, vertex_positions_clip, faces
    )
    final_norm_rgb_aa = ctx.antialias(
        final_norm_rgb, rasterized_output, vertex_positions_clip, faces
    )
    return final_pos_rgb_aa, final_norm_rgb_aa, mask_antialiased
def rasterize_position_and_normal_maps(ctx, mesh, rasterize_height, rasterize_width):
    """Rasterize the mesh over its texture coordinates into attribute maps.

    Texture coordinates are remapped with ``uv * 2 - 1``, padded with a zero
    and a one column, and rasterized with the texture-index triangles. Vertex
    positions and normals are then interpolated over that rasterization using
    the position-index triangles.

    Args:
        ctx: Rasterizer wrapper exposing ``device``, ``rasterize`` and
            ``interpolate_one``.
        mesh: Object exposing ``v_pos``, ``t_pos_idx``, ``_v_tex``,
            ``_t_tex_idx`` and ``v_normal`` tensors.
        rasterize_height: Height of the output maps.
        rasterize_width: Width of the output maps.

    Returns:
        ``(position_map, normal_map, rasterization_mask)`` where the mask is
        true for pixels whose last rasterizer channel is positive.
    """
    target = ctx.device

    # Move the mesh data onto the context's device.
    positions = mesh.v_pos.to(target)
    pos_faces = mesh.t_pos_idx.to(target)
    uvs = mesh._v_tex.to(target)
    uv_faces = mesh._t_tex_idx.to(target)
    normals = mesh.v_normal.to(target).contiguous()

    # Build 4-component rasterizer input from the texture coordinates.
    uv_ndc = uvs.unsqueeze(0) * 2.0 - 1.0
    first_column = uv_ndc[..., :1]
    z_column = torch.zeros_like(first_column)
    w_column = torch.ones_like(first_column)
    uv_homogeneous = torch.cat((uv_ndc, z_column, w_column), dim=-1)

    resolution = (rasterize_height, rasterize_width)
    rast, _ = ctx.rasterize(uv_homogeneous, uv_faces.int(), resolution)

    # Interpolate the 3-D attributes with the position-index triangles.
    pos_faces_i32 = pos_faces.int()
    position_map, _ = ctx.interpolate_one(positions, rast, pos_faces_i32)
    normal_map, _ = ctx.interpolate_one(normals, rast, pos_faces_i32)

    rasterization_mask = rast[..., 3:4] > 0
    return position_map, normal_map, rasterization_mask