# File size: 3,773 Bytes
# 17cd746
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# 
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual 
# property and proprietary rights in and to this software and related documentation. 
# Any commercial use, reproduction, disclosure or distribution of this software and 
# related documentation without an express license agreement from Toyota Motor Europe NV/SA 
# is strictly prohibited.
#


import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_keypoints


def _landmark_chain(first, last, closed=False):
    """Return edges linking consecutive indices first..last (inclusive).

    With ``closed=True`` an extra ``(first, last)`` edge is appended so the
    chain forms a loop.
    """
    chain = list(zip(range(first, last), range(first + 1, last + 1)))
    if closed:
        chain.append((first, last))
    return chain


# Edge list over landmark indices 0-67, used as ``connectivity`` when drawing
# facial landmarks. The groups appear to follow the common 68-point face layout
# (jaw, brows, nose, eyes, lips) -- NOTE(review): confirm against the landmark
# detector that produces ``lmk2d``. The last four groups are closed loops.
connectivity_face = [
    *_landmark_chain(0, 16),
    *_landmark_chain(17, 21),
    *_landmark_chain(22, 26),
    *_landmark_chain(27, 30),
    *_landmark_chain(31, 35),
    *_landmark_chain(36, 41, closed=True),
    *_landmark_chain(42, 47, closed=True),
    *_landmark_chain(48, 59, closed=True),
    *_landmark_chain(60, 67, closed=True),
]


def plot_landmarks_2d(
    img: torch.Tensor,
    lmks: torch.Tensor,
    connectivity=None,
    colors="white",
    unit=1,
    input_float=False,
):
    """Draw 2D keypoints (and optional connecting edges) onto an image.

    Thin wrapper around ``torchvision.utils.draw_keypoints`` that scales the
    keypoint radius and edge width by ``unit``. It can also accept and return
    a float image.

    Args:
        img: Image tensor handed to ``draw_keypoints``. It is channel-first
            (C, H, W) in this file's ``__main__`` usage. It must be uint8
            unless ``input_float`` is True. In that case the values are
            presumably in [0, 1] -- TODO confirm with callers.
        lmks: Keypoint coordinates forwarded to ``draw_keypoints``. The
            callers in this file pass shape (1, K, 2).
        connectivity: Optional sequence of ``(i, j)`` index pairs to join
            with line segments.
        colors: Color for keypoints and edges, forwarded as-is.
        unit: Size multiplier. Both the radius and the line width are
            ``2 * unit``.
        input_float: If True, convert ``img`` to uint8 before drawing and
            back to float (divided by 255) afterwards.

    Returns:
        The annotated image: uint8 by default, float when ``input_float``
        is True.
    """
    if input_float:
        # Round rather than truncate, so a uint8 -> float -> uint8 round trip
        # is lossless. Clamp so out-of-range values saturate instead of
        # producing undefined results when cast to uint8.
        img = (img * 255).round().clamp(0, 255).to(torch.uint8)

    img = draw_keypoints(
        img,
        lmks,
        connectivity=connectivity,
        colors=colors,
        radius=2 * unit,
        width=2 * unit,
    )

    if input_float:
        img = img.float() / 255
    return img


def blend(a: torch.Tensor, b: torch.Tensor, w: float) -> torch.Tensor:
    """Alpha-blend two images element-wise and return a uint8 result.

    Computes ``a * w + b * (1 - w)``. In this file ``a`` is the freshly
    annotated image and ``b`` the image before annotation, so ``w`` sets how
    opaque the overlay appears.

    Args:
        a: Tensor weighted by ``w``.
        b: Tensor weighted by ``1 - w``. It must broadcast with ``a``.
        w: Blend weight applied to ``a``, normally in [0, 1].

    Returns:
        uint8 tensor with values rounded to the nearest integer and clamped
        to [0, 255].
    """
    # Promote to float first so integer weights cannot overflow uint8
    # arithmetic.
    mixed = a.float() * w + b.float() * (1 - w)
    # Round instead of truncating. Truncation biases every blend downward by
    # up to one level, and the __main__ loop blends several times per frame.
    # Clamp before the cast because float -> uint8 conversion of
    # out-of-range values does not saturate.
    return mixed.round().clamp(0, 255).to(torch.uint8)


if __name__ == "__main__":
    # Interactive viewer: iterate over a NeRSemble sequence and overlay the
    # 2D bounding box, face landmarks and iris landmarks (whichever the item
    # provides) on each frame. A key press advances to the next frame.
    from argparse import ArgumentParser
    from torch.utils.data import DataLoader

    from vhap.data.nersemble_dataset import NeRSembleDataset

    parser = ArgumentParser()
    parser.add_argument("--root_folder", type=str, required=True)
    parser.add_argument("--subject", type=str, required=True)
    parser.add_argument("--sequence", type=str, required=True)
    parser.add_argument("--division", default=None)
    parser.add_argument("--subset", default=None)
    parser.add_argument("--scale_factor", type=float, default=1.0)
    parser.add_argument("--blend_weight", type=float, default=0.6)
    args = parser.parse_args()

    dataset = NeRSembleDataset(
        root_folder=args.root_folder,
        subject=args.subject,
        sequence=args.sequence,
        division=args.division,
        subset=args.subset,
        n_downsample_rgb=2,
        scale_factor=args.scale_factor,
        use_landmark=True,
    )
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    for item in dataloader:
        # Grow marker / line sizes with the item's scale factor.
        unit = int(item["scale_factor"][0] * 3) + 1

        # Dataset image is presumably (H, W, C); torchvision's drawing
        # utilities take (C, H, W).
        rgb = item["rgb"][0].permute(2, 0, 1)
        vis = rgb

        if "bbox_2d" in item:
            bbox = item["bbox_2d"][0][:4]
            tmp = draw_bounding_boxes(vis, bbox[None, ...], width=5 * unit)
            vis = blend(tmp, vis, args.blend_weight)

        if "lmk2d" in item:
            face_landmark = item["lmk2d"][0][:, :2]
            tmp = plot_landmarks_2d(
                vis,
                face_landmark[None, ...],
                connectivity=connectivity_face,
                colors="white",
                unit=unit,
            )
            vis = blend(tmp, vis, args.blend_weight)

        if "lmk2d_iris" in item:
            iris_landmark = item["lmk2d_iris"][0][:, :2]
            tmp = plot_landmarks_2d(
                vis,
                iris_landmark[None, ...],
                colors="blue",
                unit=unit,
            )
            vis = blend(tmp, vis, args.blend_weight)

        vis = vis.permute(1, 2, 0).numpy()
        # Clear the previous frame first. Otherwise every iteration stacks
        # another image artist on the same axes, so memory use and redraw
        # cost keep growing over the sequence.
        plt.cla()
        plt.imshow(vis)
        plt.draw()
        # waitforbuttonpress returns True for a key press and False for a
        # mouse click (timeout=-1 means wait indefinitely), so only a key
        # press advances to the next frame.
        while not plt.waitforbuttonpress(timeout=-1):
            pass