File size: 2,583 Bytes
17cd746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# 
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 
# property and proprietary rights in and to this software and related documentation. 
# Any commercial use, reproduction, disclosure or distribution of this software and 
# related documentation without an express license agreement from Toyota Motor Europe NV/SA 
# is strictly prohibited.
#


from typing import Literal
import tyro
import numpy as np
from PIL import Image
from pathlib import Path
import torch
import nvdiffrast.torch as dr
from vhap.util.render_uvmap import render_uvmap_vtex
from vhap.model.flame import FlameHead


# Base directory (relative to the current working directory) holding the
# FLAME assets that the UV-mask outputs are written under.
_FLAME_ASSET_DIR = "asset/flame"

# Directory receiving one PNG per FLAME region, and the single NPZ archive
# that bundles every region's boolean mask keyed by region name.
FLAME_UV_MASK_FOLDER = f"{_FLAME_ASSET_DIR}/uv_masks"
FLAME_UV_MASK_NPZ = f"{_FLAME_ASSET_DIR}/uv_masks.npz"


def main(
    use_opengl: bool = False,
    device: Literal['cuda', 'cpu'] = 'cuda',
):
    """Render one UV-space mask per FLAME region and save them to disk.

    For every ``(region, vt_mask)`` pair in ``flame_model.mask.vt``, the UV
    vertices selected by ``vt_mask`` get value 1 (all others 0). That
    per-vertex value is rendered into a 2048x2048 UV map via
    ``render_uvmap_vtex`` and thresholded at 0.5. The result is written to
    ``FLAME_UV_MASK_FOLDER/<region>.png``. All masks are also stored together
    in ``FLAME_UV_MASK_NPZ``, keyed by region name.

    Args:
        use_opengl: Use nvdiffrast's ``RasterizeGLContext`` instead of
            ``RasterizeCudaContext``.
        device: Torch device for the FLAME model and every tensor handed to
            the renderer. NOTE(review): nvdiffrast rasterization presumably
            requires CUDA tensors, so ``'cpu'`` is unlikely to work — confirm.
    """
    n_shape = 300
    n_expr = 100
    print("Initializing FLAME model")
    # Build the model exactly once and honour `device` instead of .cuda().
    flame_model = FlameHead(
        n_shape,
        n_expr,
        add_teeth=True,
    ).to(device)

    verts_uv = flame_model.verts_uvs.to(device)
    faces_uv = flame_model.textures_idx.int().to(device)
    # v_color (below) has one entry per UV vertex, so the UV face indices
    # double as the per-face color indices.
    col_idx = faces_uv

    # Rasterizer context
    glctx = dr.RasterizeGLContext() if use_opengl else dr.RasterizeCudaContext()

    h, w = 2048, 2048
    resolution = (h, w)

    out_dir = Path(FLAME_UV_MASK_FOLDER)
    out_dir.mkdir(parents=True, exist_ok=True)

    masks = {}
    for region, vt_mask in flame_model.mask.vt:
        # Single-channel per-UV-vertex value: 1 inside the region, 0 outside.
        v_color = torch.zeros(verts_uv.shape[0], 1, device=device)
        v_color[vt_mask] = 1

        # NOTE(review): [0] presumably selects the first (only) batch element
        # of the rendered map — confirm against render_uvmap_vtex.
        alpha = render_uvmap_vtex(glctx, verts_uv, faces_uv, v_color, col_idx, resolution)[0]
        # Flip along the row axis; presumably done here instead of flipping
        # verts_uv[:, 1] to match the image's vertical UV convention — confirm.
        alpha = alpha.flip(0)
        mask = (alpha > 0.5)  # to avoid overlap between hair and face
        mask = mask.squeeze(-1).cpu().numpy()
        masks[region] = mask  # (h, w)

        print(f"Saving uv mask for {region}...")
        img = Image.fromarray((mask * 255).astype(np.uint8))
        img.save(out_dir / f"{region}.png")

    print(f"Saving uv mask into: {FLAME_UV_MASK_NPZ}")
    np.savez_compressed(FLAME_UV_MASK_NPZ, **masks)


if __name__ == "__main__":
    # Let tyro build a CLI from main()'s signature (--use-opengl, --device)
    # and invoke main() with the parsed arguments.
    tyro.cli(main)