"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
Attention, antialiasing step is missing in current version.
"""
import pytorch3d.ops
import torch
import torch.nn.functional as F
# import kornia
# from kornia.geometry.camera import pixel2cam
import numpy as np
from typing import List
# import nvdiffrast.torch as dr
from scipy.io import loadmat
from torch import nn
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    RasterizationSettings,
    MeshRasterizer,
)
# def ndc_projection(x=0.1, n=1.0, f=50.0):
# return np.array([[n/x, 0, 0, 0],
# [ 0, n/-x, 0, 0],
# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
#                      [ 0, 0, -1, 0]]).astype(np.float32)


class MeshRenderer(nn.Module):
def __init__(self,
rasterize_fov,
znear=0.1,
zfar=10,
rasterize_size=224):
super(MeshRenderer, self).__init__()
# x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
# self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
# torch.diag(torch.tensor([1., -1, -1, 1])))
# self.rasterize_size = rasterize_size
# self.glctx = None
self.rasterize_size = rasterize_size
self.fov = rasterize_fov
self.znear = znear
self.zfar = zfar
        self.rasterizer = None

    def forward(self, vertex, tri, feat=None):
"""
Return:
mask -- torch.tensor, size (B, 1, H, W)
depth -- torch.tensor, size (B, 1, H, W)
features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
Parameters:
vertex -- torch.tensor, size (B, N, 3)
tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
feat(optional) -- torch.tensor, size (B, C), features
"""
device = vertex.device
rsize = int(self.rasterize_size)
# ndc_proj = self.ndc_proj.to(device)
        # convert the 3d vertices to homogeneous coordinates (the y direction is kept the same as in v)
# if vertex.shape[-1] == 3:
# vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
# vertex[..., 1] = -vertex[..., 1]
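        # flip the x axis here (rather than y as in the nvdiffrast path above); PyTorch3D's
        # camera/NDC convention has +X pointing left, +Y up and +Z into the screen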
if vertex.shape[-1] == 3:
vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
vertex[..., 0] = -vertex[..., 0]
# vertex_ndc = vertex @ ndc_proj.t()
# if self.glctx is None:
# self.glctx = dr.RasterizeGLContext(device=device)
# print("create glctx on device cuda:%d"%device.index)
        if self.rasterizer is None:
            # lazily create the rasterizer on first use
            self.rasterizer = MeshRasterizer()
            print("create rasterizer on device %s" % device)
# ranges = None
# if isinstance(tri, List) or len(tri.shape) == 3:
# vum = vertex_ndc.shape[1]
# fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
# fstartidx = torch.cumsum(fnum, dim=0) - fnum
# ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
# for i in range(tri.shape[0]):
# tri[i] = tri[i] + i*vum
# vertex_ndc = torch.cat(vertex_ndc, dim=0)
# tri = torch.cat(tri, dim=0)
        # for range_mode vertex: [B*N, 4], tri: [B*M, 3]; for instance_mode vertex: [B, N, 4], tri: [M, 3]
        tri = tri.type(torch.int64).contiguous()  # pytorch3d's Meshes expects int64 face indices
# rast_out, _ = dr.rasterize(self.glctx, vertex_ndc.contiguous(), tri, resolution=[rsize, rsize], ranges=ranges)
# depth, _ = dr.interpolate(vertex.reshape([-1,4])[...,2].unsqueeze(1).contiguous(), rast_out, tri)
# depth = depth.permute(0, 3, 1, 2)
# mask = (rast_out[..., 3] > 0).float().unsqueeze(1)
# depth = mask * depth
# image = None
# if feat is not None:
# image, _ = dr.interpolate(feat, rast_out, tri)
# image = image.permute(0, 3, 1, 2)
# image = mask * image
# rasterize
cameras = FoVPerspectiveCameras(
device=device,
fov=self.fov,
znear=self.znear,
zfar=self.zfar,
)
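        # defaults: blur_radius=0.0 and faces_per_pixel=1, so each pixel carries a single face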
raster_settings = RasterizationSettings(
image_size=rsize
)
        # print(vertex.shape, tri.shape)
        # tri is shared across the batch, so give Meshes one face list per mesh
        if tri.dim() == 2:
            tri = tri.unsqueeze(0).expand(vertex.shape[0], -1, -1)
        mesh = Meshes(vertex.contiguous()[..., :3], tri)
        fragments = self.rasterizer(mesh, cameras=cameras, raster_settings=raster_settings)
        rast_out = fragments.pix_to_face.squeeze(-1)  # (B, H, W), -1 for background pixels
        depth = fragments.zbuf
        # render depth
        depth = depth.permute(0, 3, 1, 2)
        mask = (rast_out >= 0).float().unsqueeze(1)  # pix_to_face is -1 where no face was rasterized
        depth = mask * depth
image = None
        if feat is not None:
            # gather per-vertex features for each face: (F, 3, C) with C = 3
            attributes = feat.reshape(-1, 3)[mesh.faces_packed()]
            # interpolate the per-face vertex attributes with the barycentric coordinates
            image = pytorch3d.ops.interpolate_face_attributes(
                fragments.pix_to_face, fragments.bary_coords, attributes)
            image = image.squeeze(-2).permute(0, 3, 1, 2)  # (B, H, W, 1, C) -> (B, C, H, W)
            image = mask * image
return mask, depth, image
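# A minimal, self-contained usage sketch (not part of the original pipeline): the fov/znear/zfar
# values, mesh sizes and random tensors below are illustrative placeholders only; real inputs
# come from the Deep3DFaceRecon_pytorch face model.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    renderer = MeshRenderer(rasterize_fov=12.0, znear=5.0, zfar=15.0, rasterize_size=224)

    B, N, M = 1, 1000, 1800
    vertex = torch.rand(B, N, 3, device=device) * 0.2   # (B, N, 3) vertices in camera space
    vertex[..., 2] += 10.0                               # place the mesh between znear and zfar
    tri = torch.randint(0, N, (M, 3), device=device)     # (M, 3) triangle indices
    feat = torch.rand(B, N, 3, device=device)            # (B, N, 3) per-vertex colors

    mask, depth, image = renderer(vertex, tri, feat)
    print(mask.shape, depth.shape, image.shape)          # (B,1,H,W), (B,1,H,W), (B,3,H,W)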