# File: LAM/vhap/util/visualization.py
# Last commit: 17cd746 ("feat: init") by yuandong513; file size 3.77 kB.
#
# Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual
# property and proprietary rights in and to this software and related documentation.
# Any commercial use, reproduction, disclosure or distribution of this software and
# related documentation without an express license agreement from Toyota Motor Europe NV/SA
# is strictly prohibited.
#
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_keypoints
# Edges of a 68-point facial landmark layout (presumably the standard
# iBUG 300-W / Multi-PIE ordering -- confirm against the landmark source).
# Each run is (first index, last index, closed): it contributes the edge
# (k, k + 1) for every k in [first, last), and closed runs additionally get
# the wrap-around edge (first, last).
connectivity_face = [
    edge
    for first, last, closed in (
        (0, 16, False),
        (17, 21, False),
        (22, 26, False),
        (27, 30, False),
        (31, 35, False),
        (36, 41, True),
        (42, 47, True),
        (48, 59, True),
        (60, 67, True),
    )
    for edge in (
        [(k, k + 1) for k in range(first, last)]
        + ([(first, last)] if closed else [])
    )
]
def plot_landmarks_2d(
    img: torch.Tensor,
    lmks: torch.Tensor,
    connectivity=None,
    colors="white",
    unit=1,
    input_float=False,
):
    """Draw 2D landmarks, and optionally edges between them, onto an image.

    Args:
        img: Image passed to ``torchvision.utils.draw_keypoints``; presumably
            (3, H, W) uint8 unless ``input_float`` is set -- TODO confirm.
        lmks: Keypoint coordinates forwarded unchanged to ``draw_keypoints``
            (the ``__main__`` demo passes a (1, K, 2) tensor).
        connectivity: Optional list of ``(i, j)`` landmark-index pairs to join
            with line segments.
        colors: Color of keypoints and edges, forwarded to ``draw_keypoints``.
        unit: Size multiplier; keypoint radius and line width are ``2 * unit``.
        input_float: If True, ``img`` holds floats in [0, 1]. It is converted to
            uint8 for drawing and converted back to float in [0, 1] on return.

    Returns:
        The annotated image. It is a float tensor in [0, 1] when
        ``input_float`` is True, otherwise whatever ``draw_keypoints`` returns.
    """
    if input_float:
        # Clamp and round rather than truncate. Out-of-range values would
        # otherwise wrap when cast to uint8, and truncation biases every pixel
        # downward on the float -> uint8 -> float round trip.
        img = (img.clamp(0, 1) * 255).round().byte()
    img = draw_keypoints(
        img,
        lmks,
        connectivity=connectivity,
        colors=colors,
        radius=2 * unit,
        width=2 * unit,
    )
    if input_float:
        img = img.float() / 255
    return img
def blend(a, b, w):
    """Alpha-blend two images as ``a * w + b * (1 - w)`` and return uint8.

    Args:
        a: Image tensor weighted by ``w`` (in the demo, the annotated image).
            Values are assumed to lie in the 0-255 range.
        b: Image tensor weighted by ``1 - w``; must broadcast with ``a``.
        w: Blend weight, typically in [0, 1].

    Returns:
        A uint8 tensor. The blend is computed in float32, rounded to the
        nearest integer and clamped to [0, 255] before the cast. Results are
        therefore not truncated downward, and weights outside [0, 1] cannot
        wrap around.
    """
    # Work in float so integer weights or uint8 inputs cannot overflow.
    mixed = a.float() * w + b.float() * (1 - w)
    return mixed.round().clamp(0, 255).byte()
if __name__ == "__main__":
    # Interactive viewer: step through a NeRSemble sequence and overlay the
    # dataset's 2D annotations (bounding box, face landmarks, iris landmarks)
    # on each RGB frame. Press a key to advance; mouse clicks are ignored.
    from argparse import ArgumentParser

    from torch.utils.data import DataLoader

    from vhap.data.nersemble_dataset import NeRSembleDataset

    parser = ArgumentParser()
    parser.add_argument("--root_folder", type=str, required=True)
    parser.add_argument("--subject", type=str, required=True)
    parser.add_argument("--sequence", type=str, required=True)
    parser.add_argument("--division", default=None)
    parser.add_argument("--subset", default=None)
    parser.add_argument("--scale_factor", type=float, default=1.0)
    parser.add_argument("--blend_weight", type=float, default=0.6)
    parser.add_argument("--n_downsample_rgb", type=int, default=2)
    args = parser.parse_args()

    dataset = NeRSembleDataset(
        root_folder=args.root_folder,
        subject=args.subject,
        sequence=args.sequence,
        division=args.division,
        subset=args.subset,
        n_downsample_rgb=args.n_downsample_rgb,
        scale_factor=args.scale_factor,
        use_landmark=True,
    )
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    for item in dataloader:
        # Drawing sizes (box width, keypoint radius / line width) grow with the
        # dataset's scale factor.
        unit = int(item["scale_factor"][0] * 3) + 1
        # HWC -> CHW, the layout torchvision's drawing utilities take.
        # NOTE(review): assumes item["rgb"] is uint8 -- confirm in the dataset.
        rgb = item["rgb"][0].permute(2, 0, 1)
        vis = rgb

        if "bbox_2d" in item:
            # Only the first four entries (the box corners) are drawn.
            bbox = item["bbox_2d"][0][:4]
            tmp = draw_bounding_boxes(vis, bbox[None, ...], width=5 * unit)
            vis = blend(tmp, vis, args.blend_weight)

        if "lmk2d" in item:
            # Keep only x, y; any further columns are dropped.
            face_landmark = item["lmk2d"][0][:, :2]
            tmp = plot_landmarks_2d(
                vis,
                face_landmark[None, ...],
                connectivity=connectivity_face,
                colors="white",
                unit=unit,
            )
            vis = blend(tmp, vis, args.blend_weight)

        if "lmk2d_iris" in item:
            iris_landmark = item["lmk2d_iris"][0][:, :2]
            tmp = plot_landmarks_2d(
                vis,
                iris_landmark[None, ...],
                colors="blue",
                unit=unit,
            )
            vis = blend(tmp, vis, args.blend_weight)

        vis = vis.permute(1, 2, 0).numpy()
        # Clear the axes first. Calling imshow alone would stack a new image
        # artist on top of every previous frame, so memory and redraw cost
        # would grow with each frame.
        plt.cla()
        plt.imshow(vis)
        plt.draw()
        # waitforbuttonpress returns True for a key press and False for a mouse
        # click, so only key presses advance to the next frame.
        while not plt.waitforbuttonpress(timeout=-1):
            pass