#
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
# property and proprietary rights in and to this software and related documentation.
# Any commercial use, reproduction, disclosure or distribution of this software and
# related documentation without an express license agreement from Toyota Motor Europe NV/SA
# is strictly prohibited.
#

import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_keypoints


def _polyline(start, end, closed=False):
    """Return the chain of edges ``(i, i + 1)`` for ``start <= i < end``.

    If ``closed`` is true, the loop-closing edge ``(start, end)`` is appended
    after the chain.
    """
    edges = [(i, i + 1) for i in range(start, end)]
    if closed:
        edges.append((start, end))
    return edges


# Edges between face landmarks with indices 0-67, intended as the
# ``connectivity`` argument of ``plot_landmarks_2d``. The grouping matches the
# common 68-point (iBUG 300-W / dlib) layout; the region names below assume
# that convention.
connectivity_face = (
    _polyline(0, 16)  # jaw line
    + _polyline(17, 21)  # one eyebrow
    + _polyline(22, 26)  # other eyebrow
    + _polyline(27, 30)  # nose bridge
    + _polyline(31, 35)  # lower nose
    + _polyline(36, 41, closed=True)  # one eye (closed loop)
    + _polyline(42, 47, closed=True)  # other eye (closed loop)
    + _polyline(48, 59, closed=True)  # outer lip contour (closed loop)
    + _polyline(60, 67, closed=True)  # inner lip contour (closed loop)
)


def plot_landmarks_2d(
    img: torch.Tensor,
    lmks: torch.Tensor,
    connectivity=None,
    colors="white",
    unit=1,
    input_float=False,
):
    """Draw 2D landmarks, and optionally the edges between them, on an image.

    Args:
        img: Image in the (C, H, W) layout expected by torchvision's
            ``draw_keypoints``. It must be uint8 unless ``input_float`` is set.
            With ``input_float``, it is a float image presumably in [0, 1].
        lmks: Keypoints forwarded to ``draw_keypoints``, i.e. shaped
            (num_instances, K, 2) with (x, y) pixel coordinates.
        connectivity: Optional list of ``(i, j)`` index pairs to join with lines.
        colors: Color used for both keypoints and edges.
        unit: Thickness multiplier. Marker radius and line width are ``2 * unit``.
        input_float: If true, convert the float image to uint8 for drawing and
            convert the result back to float in [0, 1].

    Returns:
        The annotated image. It is float in [0, 1] when ``input_float`` is set,
        otherwise the uint8 tensor returned by ``draw_keypoints``.
    """
    if input_float:
        # Clamp before the cast: float -> uint8 conversion of out-of-range
        # values is not well defined and would otherwise wrap or garble pixels.
        img = (img.clamp(0, 1) * 255).byte()
    img = draw_keypoints(
        img,
        lmks,
        connectivity=connectivity,
        colors=colors,
        radius=2 * unit,
        width=2 * unit,
    )
    if input_float:
        img = img.float() / 255
    return img


def blend(a, b, w):
    """Alpha-blend two images as ``w * a + (1 - w) * b`` and return uint8.

    Inputs are promoted to float so an integer ``w`` cannot overflow uint8
    arithmetic. The result is rounded to the nearest integer, so that e.g.
    ``blend(x, x, w) == x`` despite float error. It is then clamped to
    [0, 255], so weights outside [0, 1] saturate instead of wrapping around.
    """
    mixed = a.float() * w + b.float() * (1 - w)
    return mixed.round().clamp(0, 255).byte()


if __name__ == "__main__":
    from argparse import ArgumentParser

    from torch.utils.data import DataLoader

    from vhap.data.nersemble_dataset import NeRSembleDataset

    parser = ArgumentParser()
    parser.add_argument("--root_folder", type=str, required=True)
    parser.add_argument("--subject", type=str, required=True)
    parser.add_argument("--sequence", type=str, required=True)
    parser.add_argument("--division", default=None)
    parser.add_argument("--subset", default=None)
    parser.add_argument("--scale_factor", type=float, default=1.0)
    parser.add_argument("--blend_weight", type=float, default=0.6)
    args = parser.parse_args()

    dataset = NeRSembleDataset(
        root_folder=args.root_folder,
        subject=args.subject,
        sequence=args.sequence,
        division=args.division,
        subset=args.subset,
        n_downsample_rgb=2,
        scale_factor=args.scale_factor,
        use_landmark=True,
    )
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    for item in dataloader:
        # Drawing thickness grows with the dataset's scale factor.
        unit = int(item["scale_factor"][0] * 3) + 1

        # NOTE(review): assumes item["rgb"] is a batch of uint8 (H, W, C) images;
        # the torchvision draw_* helpers need (C, H, W), hence the permute.
        rgb = item["rgb"][0].permute(2, 0, 1)
        vis = rgb

        if "bbox_2d" in item:
            # Only the first 4 values are box coordinates; any extra entries
            # (presumably a confidence score) are dropped.
            bbox = item["bbox_2d"][0][:4]
            tmp = draw_bounding_boxes(vis, bbox[None, ...], width=5 * unit)
            vis = blend(tmp, vis, args.blend_weight)

        if "lmk2d" in item:
            # Keep x, y only; extra columns are presumably per-point confidence.
            face_landmark = item["lmk2d"][0][:, :2]
            tmp = plot_landmarks_2d(
                vis,
                face_landmark[None, ...],
                connectivity=connectivity_face,
                colors="white",
                unit=unit,
            )
            vis = blend(tmp, vis, args.blend_weight)

        if "lmk2d_iris" in item:
            iris_landmark = item["lmk2d_iris"][0][:, :2]
            tmp = plot_landmarks_2d(
                vis,
                iris_landmark[None, ...],
                colors="blue",
                unit=unit,
            )
            vis = blend(tmp, vis, args.blend_weight)

        vis = vis.permute(1, 2, 0).numpy()

        # Clear the previous frame so image artists do not accumulate on the
        # axes (and in memory) as we step through the dataset.
        plt.clf()
        plt.imshow(vis)
        plt.draw()
        # Wait for a key press; mouse clicks return False and are ignored.
        while not plt.waitforbuttonpress(timeout=-1):
            pass