Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,773 Bytes
17cd746 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
#
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
# property and proprietary rights in and to this software and related documentation.
# Any commercial use, reproduction, disclosure or distribution of this software and
# related documentation without an express license agreement from Toyota Motor Europe NV/SA
# is strictly prohibited.
#
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_keypoints
# Edge list (pairs of landmark indices) passed as `connectivity` when drawing
# face landmarks. Each run links consecutive indices i -> i + 1; four of the
# runs are additionally closed by an edge joining their first and last index.
# NOTE(review): the index ranges look like the common 68-point facial
# landmark layout (jaw, brows, nose, eyes, lips) -- confirm against the
# landmark detector that produces "lmk2d".
connectivity_face = [
    *((k, k + 1) for k in range(0, 16)),
    *((k, k + 1) for k in range(17, 21)),
    *((k, k + 1) for k in range(22, 26)),
    *((k, k + 1) for k in range(27, 30)),
    *((k, k + 1) for k in range(31, 35)),
    # Closed loop over indices 36..41.
    *((k, k + 1) for k in range(36, 41)),
    (36, 41),
    # Closed loop over indices 42..47.
    *((k, k + 1) for k in range(42, 47)),
    (42, 47),
    # Closed loop over indices 48..59.
    *((k, k + 1) for k in range(48, 59)),
    (48, 59),
    # Closed loop over indices 60..67.
    *((k, k + 1) for k in range(60, 67)),
    (60, 67),
]
def plot_landmarks_2d(
    img: torch.Tensor,
    lmks: torch.Tensor,
    connectivity=None,
    colors="white",
    unit=1,
    input_float=False,
):
    """Draw 2D landmarks, and optionally the lines joining them, on an image.

    Thin wrapper around ``torchvision.utils.draw_keypoints`` that scales the
    marker size with ``unit`` and can round-trip a float image through uint8.

    Args:
        img: Image tensor forwarded to ``draw_keypoints``. When
            ``input_float`` is true it is treated as a float image with
            values in [0, 1].
        lmks: Keypoint tensor forwarded to ``draw_keypoints`` unchanged.
        connectivity: Optional list of ``(start, end)`` keypoint index pairs
            to connect with lines; ``None`` draws the points only.
        colors: Colour specification forwarded to ``draw_keypoints``.
        unit: Size multiplier; both the point radius and the line width are
            ``2 * unit``.
        input_float: If true, ``img`` is multiplied by 255 and cast to uint8
            before drawing, and the result is converted back to float by
            dividing by 255.

    Returns:
        The tensor returned by ``draw_keypoints``, converted back to float
        (divided by 255) when ``input_float`` is true.
    """
    if input_float:
        # Clamp before the cast: a float image that strays slightly outside
        # [0, 1] would otherwise overflow the uint8 range and may wrap around.
        img = (img * 255).clamp(0, 255).byte()

    img = draw_keypoints(
        img,
        lmks,
        connectivity=connectivity,
        colors=colors,
        radius=2 * unit,
        width=2 * unit,
    )

    if input_float:
        img = img.float() / 255
    return img
def blend(a, b, w):
    """Return the weighted mix ``a * w + b * (1 - w)`` cast with ``.byte()``.

    Args:
        a: First operand, weighted by ``w``.
        b: Second operand, weighted by ``1 - w``.
        w: Blend weight applied to ``a``.

    Returns:
        The mixed result after calling ``.byte()`` on it.
    """
    weighted_a = a * w
    weighted_b = b * (1 - w)
    mixed = weighted_a + weighted_b
    return mixed.byte()
if __name__ == "__main__":
    # Interactive viewer: iterate over a NeRSemble sequence, overlay whatever
    # annotations each sample provides (bounding box, face landmarks, iris
    # landmarks) and show one frame at a time.
    from argparse import ArgumentParser

    from matplotlib import pyplot as plt
    from torch.utils.data import DataLoader

    from vhap.data.nersemble_dataset import NeRSembleDataset

    cli = ArgumentParser()
    cli.add_argument("--root_folder", type=str, required=True)
    cli.add_argument("--subject", type=str, required=True)
    cli.add_argument("--sequence", type=str, required=True)
    cli.add_argument("--division", default=None)
    cli.add_argument("--subset", default=None)
    cli.add_argument("--scale_factor", type=float, default=1.0)
    cli.add_argument("--blend_weight", type=float, default=0.6)
    opts = cli.parse_args()

    dataset_kwargs = dict(
        root_folder=opts.root_folder,
        subject=opts.subject,
        sequence=opts.sequence,
        division=opts.division,
        subset=opts.subset,
        n_downsample_rgb=2,
        scale_factor=opts.scale_factor,
        use_landmark=True,
    )
    loader = DataLoader(
        NeRSembleDataset(**dataset_kwargs),
        batch_size=1,
        shuffle=False,
        num_workers=4,
    )

    # Landmark overlays, applied in this order after the bounding box. Each
    # entry is (sample key, extra keyword arguments for plot_landmarks_2d).
    landmark_layers = (
        ("lmk2d", {"connectivity": connectivity_face, "colors": "white"}),
        ("lmk2d_iris", {"colors": "blue"}),
    )

    for sample in loader:
        # Stroke size grows with the sample's scale factor (always >= 1).
        unit = int(sample["scale_factor"][0] * 3) + 1

        # Batch size is 1: take the single image and move its last axis
        # to the front for the torchvision drawing helpers.
        canvas = sample["rgb"][0].permute(2, 0, 1)

        if "bbox_2d" in sample:
            box = sample["bbox_2d"][0][:4]
            boxed = draw_bounding_boxes(canvas, box[None, ...], width=5 * unit)
            canvas = blend(boxed, canvas, opts.blend_weight)

        for key, draw_kwargs in landmark_layers:
            if key not in sample:
                continue
            # Keep only the first two columns of each landmark row.
            points = sample[key][0][:, :2]
            drawn = plot_landmarks_2d(
                canvas,
                points[None, ...],
                unit=unit,
                **draw_kwargs,
            )
            canvas = blend(drawn, canvas, opts.blend_weight)

        plt.imshow(canvas.permute(1, 2, 0).numpy())
        plt.draw()

        # Hold the current frame until waitforbuttonpress returns a truthy
        # value (per the matplotlib docs: a key press, not a mouse click).
        while True:
            if plt.waitforbuttonpress(timeout=-1):
                break
|