File size: 5,630 Bytes
8ed2f16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
    Note: the antialiasing step is missing in the current version.
"""

import pytorch3d.ops
import torch
import torch.nn.functional as F
# import kornia
# from kornia.geometry.camera import pixel2cam
import numpy as np
from typing import List
# import nvdiffrast.torch as dr
from scipy.io import loadmat
from torch import nn

from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    DirectionalLights,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    TexturesUV,
)

# def ndc_projection(x=0.1, n=1.0, f=50.0):
#     return np.array([[n/x,    0,            0,              0],
#                      [  0, n/-x,            0,              0],
#                      [  0,    0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
#                      [  0,    0,           -1,              0]]).astype(np.float32)

class MeshRenderer(nn.Module):
    """Rasterize batched triangle meshes with PyTorch3D.

    PyTorch3D-based replacement for the earlier nvdiffrast implementation.
    For every mesh in the batch it produces a coverage mask, a depth map and,
    optionally, per-vertex features interpolated over the image.

    NOTE: this class intentionally keeps the name ``MeshRenderer`` for
    backward compatibility, which shadows the ``MeshRenderer`` imported from
    ``pytorch3d.renderer`` at the top of this file.

    Parameters:
        rasterize_fov   -- field of view forwarded as ``fov`` to
                           ``FoVPerspectiveCameras`` (presumably degrees, the
                           PyTorch3D default -- confirm against callers)
        znear           -- near clipping plane forwarded to the camera
        zfar            -- far clipping plane forwarded to the camera
        rasterize_size  -- output image side length in pixels (cast to int)
    """

    def __init__(self,
                 rasterize_fov,
                 znear=0.1,
                 zfar=10,
                 rasterize_size=224):
        super().__init__()

        self.rasterize_size = rasterize_size
        self.fov = rasterize_fov
        self.znear = znear
        self.zfar = zfar

        # Created lazily on the first forward() call.
        self.rasterizer = None

    @staticmethod
    def _flip_x(vertex):
        """Return the xyz part of ``vertex`` with the x axis negated (input is not modified)."""
        if vertex.shape[-1] != 3:
            # NOTE(review): the original code only flipped x for 3-channel
            # input; already-homogeneous (..., 4) input is passed through
            # unflipped. Kept as-is -- confirm whether that is intended.
            return vertex[..., :3].contiguous()
        # Presumably converts to PyTorch3D's camera convention -- TODO confirm.
        sign = vertex.new_tensor([-1.0, 1.0, 1.0])
        return (vertex * sign).contiguous()

    @staticmethod
    def _batch_faces(tri, batch_size):
        """Return ``tri`` as an int64 tensor of shape (batch_size, M, 3)."""
        faces = tri.to(torch.int64)
        if faces.dim() == 2:
            faces = faces.unsqueeze(0)
        if faces.dim() != 3 or faces.shape[-1] != 3:
            raise ValueError(
                f"tri must have shape (M, 3) or (B, M, 3), got {tuple(tri.shape)}"
            )
        if faces.shape[0] == 1 and batch_size != 1:
            # One shared topology for every mesh in the batch.
            faces = faces.expand(batch_size, -1, -1)
        return faces.contiguous()

    @staticmethod
    def _coverage_mask(pix_to_face):
        """Turn a (B, H, W, 1) ``pix_to_face`` map into a float (B, 1, H, W) mask."""
        # Empty pixels are padded with -1, so face index 0 is a valid hit.
        return (pix_to_face >= 0).float().permute(0, 3, 1, 2)

    def forward(self, vertex, tri, feat=None):
        """
        Return:
            mask               -- torch.tensor, size (B, 1, H, W), 1 where a face covers the pixel
            depth              -- torch.tensor, size (B, 1, H, W), rasterizer z-buffer, 0 on background
            features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None, else None

        Parameters:
            vertex          -- torch.tensor, size (B, N, 3)
            tri             -- torch.tensor, size (B, M, 3) or (M, 3), triangles
            feat(optional)  -- torch.tensor, size (B, N, C), per-vertex features
        """
        device = vertex.device
        rsize = int(self.rasterize_size)

        verts = self._flip_x(vertex)
        faces = self._batch_faces(tri, verts.shape[0])

        if self.rasterizer is None:
            self.rasterizer = MeshRasterizer()
            # Formatting the device itself also works for CPU tensors,
            # where device.index is None.
            print(f"create rasterizer on device {device}")

        # Camera uses PyTorch3D's default extrinsics (no R / T are passed).
        cameras = FoVPerspectiveCameras(
            device=device,
            fov=self.fov,
            znear=self.znear,
            zfar=self.zfar,
        )
        raster_settings = RasterizationSettings(image_size=rsize)

        mesh = Meshes(verts, faces)
        fragments = self.rasterizer(
            mesh, cameras=cameras, raster_settings=raster_settings
        )

        # render depth
        mask = self._coverage_mask(fragments.pix_to_face)
        depth = mask * fragments.zbuf.permute(0, 3, 1, 2)

        image = None
        if feat is not None:
            # faces_packed() indexes into the flattened (B*N) vertex list, so
            # the features are flattened the same way before gathering.
            channels = feat.shape[-1]
            attributes = feat.reshape(-1, channels)[mesh.faces_packed()]
            image = pytorch3d.ops.interpolate_face_attributes(
                fragments.pix_to_face,
                fragments.bary_coords,
                attributes,
            )
            # (B, H, W, 1, C) -> (B, C, H, W)
            image = image.squeeze(-2).permute(0, 3, 1, 2)
            image = mask * image

        return mask, depth, image