File size: 2,995 Bytes
17cd746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# 
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 
# property and proprietary rights in and to this software and related documentation. 
# Any commercial use, reproduction, disclosure or distribution of this software and 
# related documentation without an express license agreement from Toyota Motor Europe NV/SA 
# is strictly prohibited.
#


import tyro
import matplotlib.pyplot as plt
import numpy as np
import torch
import nvdiffrast.torch as dr

from vhap.model.flame import FlameHead


FLAME_TEX_PATH = "asset/flame/FLAME_texture.npz"


def transform_vt(vt):
    """Transform uv vertices to clip space"""
    xy = vt * 2 - 1
    w = torch.ones([1, vt.shape[-2], 1]).to(vt)
    z = -w  # In the clip spcae of OpenGL, the camera looks at -z
    xyzw = torch.cat([xy[None, :, :], z, w], axis=-1)    
    return xyzw

def render_uvmap_vtex(glctx, pos, pos_idx, v_color, col_idx, resolution):
    """Render uv map with vertex color"""
    pos_clip = transform_vt(pos)
    rast_out, _ = dr.rasterize(glctx, pos_clip, pos_idx, resolution)
    
    color, _ = dr.interpolate(v_color, rast_out, col_idx)
    color = dr.antialias(color, rast_out, pos_clip, pos_idx)
    return color

def render_uvmap_texmap(glctx, pos, pos_idx, verts_uv, faces_uv, tex, resolution, enable_mip=True, max_mip_level=None):
    """Render uv map with texture map"""
    pos_clip = transform_vt(pos)
    rast_out, rast_out_db = dr.rasterize(glctx, pos_clip, pos_idx, resolution)

    if enable_mip:
        texc, texd = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv, rast_db=rast_out_db, diff_attrs='all')
        color = dr.texture(tex[None, ...], texc, texd, filter_mode='linear-mipmap-linear', max_mip_level=max_mip_level)
    else:
        texc, _ = dr.interpolate(verts_uv[None, ...], rast_out, faces_uv)
        color = dr.texture(tex[None, ...], texc, filter_mode='linear')
    color = dr.antialias(color, rast_out, pos_clip, pos_idx)
    return color


def main(
    use_texmap: bool = False,
    use_opengl: bool = False,
):
    n_shape = 300
    n_expr = 100
    print("Initialization FLAME model")
    flame_model = FlameHead(n_shape, n_expr)

    verts_uv = flame_model.verts_uvs.cuda()
    verts_uv[:, 1] = 1 - verts_uv[:, 1]
    faces_uv = flame_model.textures_idx.int().cuda()

    # Rasterizer context
    glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext()

    h, w = 512, 512
    resolution = (h, w)
    
    if use_texmap:
        tex = torch.from_numpy(np.load(FLAME_TEX_PATH)['mean']).cuda().float().flip(dims=[-1]) / 255
        rgb = render_uvmap_texmap(glctx, verts_uv, faces_uv, verts_uv, faces_uv, tex, resolution, enable_mip=True)
    else:
        v_color = torch.ones(verts_uv.shape[0], 3).to(verts_uv)
        col_idx = faces_uv
        rgb = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution)
    
    plt.imshow(rgb[0, :, :].cpu())
    plt.show()


if __name__ == "__main__":
    tyro.cli(main)