import os
import sys
sys.path.insert(0, "stylegan-encoder")
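# The vendored "stylegan-encoder" directory provides the ffhq_dataset
# face-alignment helpers imported below.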
import tempfile
import warnings
import imageio
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
import torchvision.transforms as transforms
import dlib
from cog import BasePredictor, Path, Input
from demo import load_checkpoints
from demo import make_animation
from ffhq_dataset.face_alignment import image_align
from ffhq_dataset.landmarks_detector import LandmarksDetector
warnings.filterwarnings("ignore")
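# dlib's 68-point facial landmark model backs both the raw shape predictor
# and the LandmarksDetector used for "vox" face alignment.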
PREDICTOR = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
LANDMARKS_DETECTOR = LandmarksDetector("shape_predictor_68_face_landmarks.dat")
class Predictor(BasePredictor):
    def setup(self):
        self.device = torch.device("cuda:0")
        # Load one set of model components (inpainting network, keypoint
        # detector, dense motion network, AVD network) per supported dataset.
        datasets = ["vox", "taichi", "ted", "mgif"]
        (
            self.inpainting,
            self.kp_detector,
            self.dense_motion_network,
            self.avd_network,
        ) = ({}, {}, {}, {})
        for d in datasets:
            (
                self.inpainting[d],
                self.kp_detector[d],
                self.dense_motion_network[d],
                self.avd_network[d],
            ) = load_checkpoints(
                # The "ted" checkpoint was trained at 384x384; the others at 256x256.
                config_path=f"config/{d}-384.yaml"
                if d == "ted"
                else f"config/{d}-256.yaml",
                checkpoint_path=f"checkpoints/{d}.pth.tar",
                device=self.device,
            )
    def predict(
        self,
        source_image: Path = Input(
            description="Input source image.",
        ),
        driving_video: Path = Input(
            description="Choose a micromotion.",
        ),
        dataset_name: str = Input(
            choices=["vox", "taichi", "ted", "mgif"],
            default="vox",
            description="Choose a dataset.",
        ),
    ) -> Path:
        predict_mode = "relative"  # one of ['standard', 'relative', 'avd']
        # find_best_frame = False
        # The "ted" models expect 384x384 inputs; all others expect 256x256.
        pixel = 384 if dataset_name == "ted" else 256
        if dataset_name == "vox":
            # Run FFHQ-style face alignment first so the input matches the
            # face crop the "vox" checkpoint was trained on.
            align_image(str(source_image), "aligned.png")
            source_image = imageio.imread("aligned.png")
        else:
            source_image = imageio.imread(str(source_image))
        reader = imageio.get_reader(str(driving_video))
        fps = reader.get_meta_data()["fps"]
        source_image = resize(source_image, (pixel, pixel))[..., :3]
        # Read the driving video frame by frame; imageio raises RuntimeError
        # on a truncated stream, in which case we keep the frames read so far.
        driving_video = []
        try:
            for im in reader:
                driving_video.append(im)
        except RuntimeError:
            pass
        reader.close()
        driving_video = [
            resize(frame, (pixel, pixel))[..., :3] for frame in driving_video
        ]
        # Select the model components loaded for the requested dataset.
        inpainting, kp_detector, dense_motion_network, avd_network = (
            self.inpainting[dataset_name],
            self.kp_detector[dataset_name],
            self.dense_motion_network[dataset_name],
            self.avd_network[dataset_name],
        )
        predictions = make_animation(
            source_image,
            driving_video,
            inpainting,
            kp_detector,
            dense_motion_network,
            avd_network,
            device=self.device,
            mode=predict_mode,
        )
        # Save the resulting video at the driving video's frame rate.
        out_path = Path(tempfile.mkdtemp()) / "output.mp4"
        imageio.mimsave(
            str(out_path), [img_as_ubyte(frame) for frame in predictions], fps=fps
        )
        return out_path
def align_image(raw_img_path, aligned_face_path):
    # Align every detected face; if the image contains several faces, each
    # write overwrites the previous one, so only the last face is kept.
    for face_landmarks in LANDMARKS_DETECTOR.get_landmarks(raw_img_path):
        image_align(raw_img_path, aligned_face_path, face_landmarks)
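# Example invocation with the Cog CLI (the input file names here are
# placeholders, not files shipped with the repo):
#
#   cog predict \
#       -i source_image=@face.png \
#       -i driving_video=@micromotion.mp4 \
#       -i dataset_name=vox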