File size: 7,759 Bytes
6d4bcdf
 
 
 
 
 
 
 
 
1d5bb62
 
 
6d4bcdf
 
1d5bb62
 
6d4bcdf
1d5bb62
 
 
 
6d4bcdf
1d5bb62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# This file uses nvdiffrast library, which is licensed under the NVIDIA Source Code License (1-Way Commercial).
# nvdiffrast is available for non-commercial use (research or evaluation purposes only).
# For commercial use, please contact NVIDIA for licensing: https://www.nvidia.com/en-us/research/inquiries/
#
# nvdiffrast copyright: Copyright (c) 2020, NVIDIA Corporation. All rights reserved.
# Full license: https://github.com/NVlabs/nvdiffrast/blob/main/LICENSE.txt

from typing import Tuple, Union

import nvdiffrast.torch as dr
import torch
from jaxtyping import Float, Integer
from torch import Tensor


class NVDiffRasterizerContext:
    """Thin wrapper around an nvdiffrast rasterization context.

    The wrapper owns the nvdiffrast context created for ``device`` and exposes
    helpers that cast their tensor arguments (``float`` for positions,
    attributes and colors, ``int`` for triangle indices) before forwarding
    them to the corresponding ``nvdiffrast.torch`` function.
    """

    def __init__(self, context_type: str, device) -> None:
        """Create the underlying nvdiffrast context.

        Args:
            context_type: ``"gl"`` or ``"cuda"``.
            device: Device handed to the nvdiffrast context constructor.

        Raises:
            ValueError: If ``context_type`` is neither ``"gl"`` nor ``"cuda"``.
        """
        self.device = device
        self.ctx = self.initialize_context(context_type, device)

    def initialize_context(
        self, context_type: str, device
    ) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
        """Build the nvdiffrast context selected by ``context_type``.

        Raises:
            ValueError: If ``context_type`` is neither ``"gl"`` nor ``"cuda"``.
        """
        if context_type == "gl":
            return dr.RasterizeGLContext(device=device)
        if context_type == "cuda":
            return dr.RasterizeCudaContext(device=device)
        raise ValueError(f"Unknown rasterizer context type: {context_type}")

    @staticmethod
    def _normalize_resolution(
        resolution: Union[int, Tuple[int, int]]
    ) -> Union[Tuple[int, int], list]:
        """Expand a single ``int`` resolution to a square ``(n, n)`` pair.

        Any non-``int`` value is returned unchanged so sequences keep being
        forwarded exactly as the caller supplied them.
        """
        if isinstance(resolution, int):
            return (resolution, resolution)
        return resolution

    def vertex_transform(
        self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
    ) -> Float[Tensor, "B Nv 4"]:
        """Transform vertices to clip space with a batch of MVP matrices.

        Args:
            verts: Vertex positions, one row per vertex.
            mvp_mtx: Batch of 4x4 matrices applied to the homogeneous vertices.

        Returns:
            ``verts`` extended with ``w = 1`` and multiplied by the transpose
            of every matrix in ``mvp_mtx`` (i.e. ``mvp @ v`` per vertex).
        """
        # Autocast is disabled so the matmul runs in the inputs' own dtype.
        with torch.amp.autocast("cuda", enabled=False):
            # Allocate the w column directly with verts' dtype and device
            # instead of creating it on the CPU and moving it afterwards.
            w_column = verts.new_ones((verts.shape[0], 1))
            verts_homo = torch.cat([verts, w_column], dim=-1)
            verts_clip = torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1))
        return verts_clip

    def rasterize(
        self,
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ) -> Tuple[Tensor, Tensor]:
        """Rasterize a batch of clip-space vertices sharing one topology.

        Args:
            pos: Clip-space vertex positions; cast to ``float``.
            tri: Triangle vertex indices; cast to ``int``.
            resolution: ``(height, width)``, or a single ``int`` for a square
                output.

        Returns:
            The pair returned by ``dr.rasterize`` (called with
            ``grad_db=True``); callers in this file unpack it as
            ``(rast, rast_db)``.
        """
        # rasterize in instance mode (single topology)
        return dr.rasterize(
            self.ctx,
            pos.float(),
            tri.int(),
            # dr.rasterize expects a (height, width) sequence, so a bare int
            # is expanded here rather than failing inside nvdiffrast.
            self._normalize_resolution(resolution),
            grad_db=True,
        )

    def rasterize_one(
        self,
        pos: Float[Tensor, "Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
        resolution: Union[int, Tuple[int, int]],
    ) -> Tuple[Tensor, Tensor]:
        """Rasterize a single mesh under a single viewpoint.

        Adds a leading batch axis to ``pos``, calls :meth:`rasterize` and
        returns the first (only) batch entry of both outputs.
        """
        rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
        return rast[0], rast_db[0]

    def antialias(
        self,
        color: Float[Tensor, "B H W C"],
        rast: Float[Tensor, "B H W 4"],
        pos: Float[Tensor, "B Nv 4"],
        tri: Integer[Tensor, "Nf 3"],
    ) -> Float[Tensor, "B H W C"]:
        """Forward to ``dr.antialias`` after casting ``color``/``pos`` to
        ``float`` and ``tri`` to ``int``; ``rast`` is passed unchanged."""
        return dr.antialias(color.float(), rast, pos.float(), tri.int())

    def interpolate(
        self,
        attr: Float[Tensor, "B Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate per-vertex attributes over a rasterization result.

        Args:
            attr: Per-vertex attributes; cast to ``float``.
            rast: Rasterization output, passed through unchanged.
            tri: Triangle vertex indices; cast to ``int``.
            rast_db: Forwarded to ``dr.interpolate``.
            diff_attrs: Forwarded to ``dr.interpolate``.

        Returns:
            The pair returned by ``dr.interpolate``. Callers in this file
            unpack it as ``(interpolated, _)``.
        """
        return dr.interpolate(
            attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
        )

    def interpolate_one(
        self,
        attr: Float[Tensor, "Nv C"],
        rast: Float[Tensor, "B H W 4"],
        tri: Integer[Tensor, "Nf 3"],
        rast_db=None,
        diff_attrs=None,
    ) -> Tuple[Float[Tensor, "B H W C"], Tensor]:
        """Interpolate attributes given without a batch axis.

        Adds a leading batch axis of size one to ``attr`` and delegates to
        :meth:`interpolate`, returning its result unchanged.
        """
        return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs)

def texture_map_to_rgb(tex_map, uv_coordinates):
    """Look up ``tex_map`` at ``uv_coordinates`` through ``dr.texture``.

    The texture is cast to ``float`` first; ``uv_coordinates`` is forwarded
    unchanged and every other ``dr.texture`` option keeps its default.
    """
    texture = tex_map.float()
    sampled = dr.texture(texture, uv_coordinates)
    return sampled

def render_rgb_from_texture_mesh_with_mask(
    ctx,
    mesh,
    tex_map: Float[Tensor, "1 H W C"],
    mvp_matrix: Float[Tensor, "batch 4 4"],
    image_height: int,
    image_width: int,
    background_color: Tensor = torch.tensor([0.0, 0.0, 0.0]),
):
    """Render a textured mesh to RGB images and return the coverage mask.

    Args:
        ctx: Rasterizer wrapper providing ``vertex_transform``, ``rasterize``,
            ``antialias`` and ``interpolate_one``.
        mesh: Object exposing ``v_pos``, ``t_pos_idx``, ``_v_tex`` and
            ``_t_tex_idx`` tensors.
        tex_map: Texture image; a 3-D tensor gets a leading batch axis added.
        mvp_matrix: One 4x4 matrix per view; its device is used for all work.
        image_height: Output image height in pixels.
        image_width: Output image width in pixels.
        background_color: Three values used for pixels not covered by the
            mesh. The default tensor is only read, never modified.

    Returns:
        ``(rgb, selector)`` where ``rgb`` is the antialiased image batch and
        ``selector`` is the boolean ``(batch, H, W)`` mask of pixels whose
        rasterizer triangle id is greater than zero.
    """
    device = mvp_matrix.device
    batch_size = mvp_matrix.shape[0]

    # Bring every input to the render device once so a mesh/texture that
    # still lives elsewhere does not trigger a device-mismatch error.
    tex_map = tex_map.to(device).contiguous()
    if tex_map.dim() == 3:
        tex_map = tex_map.unsqueeze(0)  # Add batch dimension if missing
    v_pos = mesh.v_pos.to(device)
    t_pos_idx = mesh.t_pos_idx.to(device)
    v_tex = mesh._v_tex.to(device)
    t_tex_idx = mesh._t_tex_idx.to(device)

    vertex_positions_clip = ctx.vertex_transform(v_pos, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(
        vertex_positions_clip, t_pos_idx, (image_height, image_width)
    )
    # Last rasterizer channel > 0 marks pixels covered by some triangle.
    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(
        mask.float(), rasterized_output, vertex_positions_clip, t_pos_idx
    )

    interpolated_texture_coords, _ = ctx.interpolate_one(
        v_tex, rasterized_output, t_tex_idx
    )
    # texture_map_to_rgb casts the texture to float itself.
    rgb_foreground = texture_map_to_rgb(tex_map, interpolated_texture_coords)

    # Zero out uncovered pixels without boolean-index gather/scatter.
    rgb_foreground_batched = torch.where(
        mask, rgb_foreground, torch.zeros_like(rgb_foreground)
    )
    rgb_background_batched = (
        background_color.view(1, 1, 1, 3)
        .to(rgb_foreground)
        .expand(batch_size, image_height, image_width, 3)
    )

    # Use the anti-aliased mask for blending
    final_rgb = torch.lerp(
        rgb_background_batched, rgb_foreground_batched, mask_antialiased
    )
    final_rgb_aa = ctx.antialias(
        final_rgb, rasterized_output, vertex_positions_clip, t_pos_idx
    )

    selector = mask[..., 0]
    return final_rgb_aa, selector


def render_geo_from_mesh(ctx, mesh, mvp_matrix, image_height, image_width):
    """Render per-pixel position and normal images of a mesh.

    Args:
        ctx: Rasterizer wrapper providing ``vertex_transform``, ``rasterize``,
            ``antialias`` and ``interpolate_one``.
        mesh: Object exposing ``v_pos``, ``t_pos_idx`` and ``v_normal``.
        mvp_matrix: One 4x4 matrix per view; its device is used for all work.
        image_height: Output image height in pixels.
        image_width: Output image width in pixels.

    Returns:
        ``(positions, normals, mask)``: the antialiased interpolated vertex
        positions, the antialiased interpolated vertex normals, and the
        antialiased coverage mask. Uncovered pixels blend towards zero.
    """
    device = mvp_matrix.device
    batch_size = mvp_matrix.shape[0]

    # Transfer each mesh tensor once instead of on every use.
    v_pos = mesh.v_pos.to(device)
    t_pos_idx = mesh.t_pos_idx.to(device)
    v_normal = mesh.v_normal.to(device).contiguous()

    vertex_positions_clip = ctx.vertex_transform(v_pos, mvp_matrix)
    rasterized_output, _ = ctx.rasterize(
        vertex_positions_clip, t_pos_idx, (image_height, image_width)
    )
    interpolated_positions, _ = ctx.interpolate_one(
        v_pos, rasterized_output, t_pos_idx
    )
    interpolated_normals, _ = ctx.interpolate_one(
        v_normal, rasterized_output, t_pos_idx
    )

    # Last rasterizer channel > 0 marks pixels covered by some triangle.
    mask = rasterized_output[..., 3:] > 0
    mask_antialiased = ctx.antialias(
        mask.float(), rasterized_output, vertex_positions_clip, t_pos_idx
    )

    # One all-zero buffer, allocated directly with the outputs' dtype/device,
    # serves as both the background and the fill for uncovered pixels.
    rgb_background_batched = interpolated_positions.new_zeros(
        (batch_size, image_height, image_width, 3)
    )
    rgb_foreground_pos_batched = torch.where(
        mask, interpolated_positions, rgb_background_batched
    )
    rgb_foreground_norm_batched = torch.where(
        mask, interpolated_normals, rgb_background_batched
    )

    final_pos_rgb = torch.lerp(
        rgb_background_batched, rgb_foreground_pos_batched, mask_antialiased
    )
    final_norm_rgb = torch.lerp(
        rgb_background_batched, rgb_foreground_norm_batched, mask_antialiased
    )
    final_pos_rgb_aa = ctx.antialias(
        final_pos_rgb, rasterized_output, vertex_positions_clip, t_pos_idx
    )
    final_norm_rgb_aa = ctx.antialias(
        final_norm_rgb, rasterized_output, vertex_positions_clip, t_pos_idx
    )

    return final_pos_rgb_aa, final_norm_rgb_aa, mask_antialiased

def rasterize_position_and_normal_maps(ctx, mesh, rasterize_height, rasterize_width):
    """Rasterize the mesh in UV space and interpolate positions and normals.

    The mesh's texture coordinates are used as the rasterization positions,
    so each output pixel corresponds to a location in the UV layout.

    Args:
        ctx: Rasterizer wrapper providing ``device``, ``rasterize`` and
            ``interpolate_one``.
        mesh: Object exposing ``v_pos``, ``t_pos_idx``, ``_v_tex``,
            ``_t_tex_idx`` and ``v_normal``.
        rasterize_height: Height of the rasterized maps.
        rasterize_width: Width of the rasterized maps.

    Returns:
        ``(position_map, normal_map, rasterization_mask)``; the mask is the
        boolean ``> 0`` test on the rasterizer's last channel (kept as a
        trailing axis of size one).
    """
    target = ctx.device

    # Move all mesh tensors onto the rasterizer's device.
    positions = mesh.v_pos.to(target)
    faces = mesh.t_pos_idx.to(target)
    tex_coords = mesh._v_tex.to(target)
    tex_faces = mesh._t_tex_idx.to(target)
    normals = mesh.v_normal.to(target).contiguous()

    # Rescale UVs by 2x - 1 (presumably [0, 1] -> [-1, 1]; confirm UV range)
    # and append z = 0, w = 1 so they can be rasterized as clip-space points.
    uv_scaled = tex_coords[None, ...] * 2.0 - 1.0
    z_column = torch.zeros_like(uv_scaled[..., :1])
    w_column = torch.ones_like(uv_scaled[..., :1])
    uv_as_clip = torch.cat((uv_scaled, z_column, w_column), dim=-1)

    map_size = (rasterize_height, rasterize_width)
    raster, _ = ctx.rasterize(uv_as_clip, tex_faces.int(), map_size)

    # Attributes are looked up with the position face indices, while the
    # rasterization above used the texture face indices.
    position_map, _ = ctx.interpolate_one(positions, raster, faces.int())
    normal_map, _ = ctx.interpolate_one(normals, raster, faces.int())

    coverage = raster[..., 3:4] > 0
    return position_map, normal_map, coverage