Spaces:
Runtime error
Runtime error
import matplotlib | |
matplotlib.use('Agg') | |
import os, sys | |
import yaml | |
from argparse import ArgumentParser | |
from tqdm import tqdm | |
import imageio | |
import numpy as np | |
from skimage.transform import resize | |
from skimage import img_as_ubyte | |
import torch | |
import torch.nn.functional as F | |
from sync_batchnorm import DataParallelWithCallback | |
from modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator | |
from modules.keypoint_detector import KPDetector, HEEstimator | |
from animate import normalize_kp | |
from scipy.spatial import ConvexHull | |
# Refuse to run under Python 2. Note that this module also uses f-strings, so
# interpreters older than 3.6 already fail with a SyntaxError while parsing,
# before this check is ever reached.
if sys.version_info.major < 3:
    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")
def load_checkpoints(config_path, checkpoint_path, gen, cpu=False):
    """Build the generator, keypoint detector and head-pose estimator and load their weights.

    Args:
        config_path: Path to the YAML config. Constructor kwargs are read from
            ``model_params`` (``generator_params``, ``kp_detector_params``,
            ``he_estimator_params``), each merged with ``common_params``.
        checkpoint_path: Path to a ``torch.load``-able checkpoint holding the
            ``'generator'``, ``'kp_detector'`` and ``'he_estimator'`` state dicts.
        gen: ``'original'`` selects OcclusionAwareGenerator, ``'spade'``
            selects OcclusionAwareSPADEGenerator.
        cpu: If True, keep everything on the CPU (checkpoint tensors are mapped
            to the CPU); otherwise move the models to CUDA and wrap each one in
            DataParallelWithCallback.

    Returns:
        Tuple ``(generator, kp_detector, he_estimator)``, all in eval mode.

    Raises:
        ValueError: If ``gen`` is neither ``'original'`` nor ``'spade'``.
    """
    # Resolve the generator class first so a bad ``gen`` fails fast with a clear
    # message instead of an UnboundLocalError on ``generator`` further down.
    if gen == 'original':
        generator_cls = OcclusionAwareGenerator
    elif gen == 'spade':
        generator_cls = OcclusionAwareSPADEGenerator
    else:
        raise ValueError(f"Unknown generator type {gen!r}; expected 'original' or 'spade'")

    with open(config_path) as f:
        # yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6;
        # safe_load is sufficient for a plain data config and does not build
        # arbitrary Python objects.
        config = yaml.safe_load(f)

    model_params = config['model_params']
    common_params = model_params['common_params']

    generator = generator_cls(**model_params['generator_params'], **common_params)
    kp_detector = KPDetector(**model_params['kp_detector_params'], **common_params)
    he_estimator = HEEstimator(**model_params['he_estimator_params'], **common_params)

    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        generator.cuda()
        kp_detector.cuda()
        he_estimator.cuda()
        checkpoint = torch.load(checkpoint_path)

    generator.load_state_dict(checkpoint['generator'])
    kp_detector.load_state_dict(checkpoint['kp_detector'])
    he_estimator.load_state_dict(checkpoint['he_estimator'])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)
        he_estimator = DataParallelWithCallback(he_estimator)

    generator.eval()
    kp_detector.eval()
    he_estimator.eval()
    return generator, kp_detector, he_estimator
def headpose_pred_to_degree(pred):
    """Decode binned head-pose logits into an angle in degrees.

    A softmax over dim 1 (66 bins) gives a distribution whose expected bin
    index ``i`` is mapped linearly to ``3 * i - 99``. That is, 3 degrees per
    bin, covering [-99, 96].

    Args:
        pred: Logits tensor of shape (B, 66).

    Returns:
        Tensor of shape (B,) with the expected angle in degrees.
    """
    idx_tensor = torch.arange(66, dtype=torch.float32, device=pred.device)
    # Explicit dim: an implicit-dim softmax is deprecated and emits a warning.
    prob = F.softmax(pred, dim=1)
    return torch.sum(prob * idx_tensor, dim=1) * 3 - 99
# Inactive alternative implementation. It is wrapped in a bare string literal,
# so it is never executed and never bound to a name. Compared with the active
# get_rotation_matrix defined below, it rotates roll about the x axis, pitch
# about the y axis and yaw about the z axis, and composes the result as
# roll_mat @ pitch_mat @ yaw_mat.
'''
# beta version
def get_rotation_matrix(yaw, pitch, roll):
    yaw = yaw / 180 * 3.14
    pitch = pitch / 180 * 3.14
    roll = roll / 180 * 3.14
    roll = roll.unsqueeze(1)
    pitch = pitch.unsqueeze(1)
    yaw = yaw.unsqueeze(1)
    roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll),
                          torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll),
                          torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1)
    roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
    pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch),
                           torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch),
                           -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1)
    pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
    yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw),
                         torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw),
                         torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1)
    yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
    rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat)
    return rot_mat
'''
def get_rotation_matrix(yaw, pitch, roll):
    """Build one 3x3 rotation matrix per sample from Euler angles in degrees.

    Axis convention: pitch rotates about x, yaw about y and roll about z. The
    result is ``pitch_mat @ yaw_mat @ roll_mat`` for every sample.

    Degrees are converted to radians with the constant 3.14 rather than
    math.pi. NOTE(review): presumably kept so outputs match how the pretrained
    weights were produced; confirm before changing it.

    Args:
        yaw, pitch, roll: Tensors of shape (B,), one angle per sample.

    Returns:
        Tensor of shape (B, 3, 3).
    """
    def as_column(angle):
        # degrees -> radians (3.14 approximation), then (B,) -> (B, 1)
        return (angle / 180 * 3.14).unsqueeze(1)

    def to_matrix(entries):
        # Nine (B, 1) tensors in row-major order -> (B, 3, 3)
        flat = torch.cat(entries, dim=1)
        return flat.view(flat.shape[0], 3, 3)

    p = as_column(pitch)
    y = as_column(yaw)
    r = as_column(roll)

    p_one, p_zero = torch.ones_like(p), torch.zeros_like(p)
    about_x = to_matrix([
        p_one, p_zero, p_zero,
        p_zero, torch.cos(p), -torch.sin(p),
        p_zero, torch.sin(p), torch.cos(p),
    ])

    y_one, y_zero = torch.ones_like(y), torch.zeros_like(y)
    about_y = to_matrix([
        torch.cos(y), y_zero, torch.sin(y),
        y_zero, y_one, y_zero,
        -torch.sin(y), y_zero, torch.cos(y),
    ])

    r_one, r_zero = torch.ones_like(r), torch.zeros_like(r)
    about_z = to_matrix([
        torch.cos(r), -torch.sin(r), r_zero,
        torch.sin(r), torch.cos(r), r_zero,
        r_zero, r_zero, r_one,
    ])

    return torch.einsum('bij,bjk,bkm->bim', about_x, about_y, about_z)
def keypoint_transformation(kp_canonical, he, estimate_jacobian=True, free_view=False, yaw=0, pitch=0, roll=0):
    """Pose canonical keypoints with a head rotation, translation and expression offsets.

    The keypoints are rotated by the matrix from ``get_rotation_matrix``,
    shifted by ``he['t']`` (broadcast to every keypoint) and offset by
    ``he['exp']``. The jacobians, if requested, are rotated by the same matrix.

    Args:
        kp_canonical: Dict with ``'value'`` (used as (B, K, 3) by the einsum
            below) and, when ``estimate_jacobian`` is True, ``'jacobian'``.
        he: Head-pose estimator output containing ``'yaw'``, ``'pitch'`` and
            ``'roll'`` (binned logits decoded by ``headpose_pred_to_degree``),
            plus ``'t'`` and ``'exp'``. ``he`` is not modified.
        estimate_jacobian: Whether to rotate ``kp_canonical['jacobian']``.
        free_view: If True, each of ``yaw``/``pitch``/``roll`` that is not None
            replaces the estimated angle with that value in degrees. None keeps
            the estimate. When False, the three arguments are ignored.
        yaw, pitch, roll: Angle overrides, used only when ``free_view`` is True.

    Returns:
        Dict with ``'value'`` (transformed keypoints) and ``'jacobian'``
        (rotated jacobians, or None when ``estimate_jacobian`` is False).
    """
    kp = kp_canonical['value']

    def resolve_angle(name, override):
        if free_view and override is not None:
            # Build the override on the keypoints' device instead of forcing
            # CUDA, so free-view also works in CPU mode.
            return torch.tensor([override], device=kp.device)
        return headpose_pred_to_degree(he[name])

    yaw = resolve_angle('yaw', yaw)
    pitch = resolve_angle('pitch', pitch)
    roll = resolve_angle('roll', roll)
    t, exp = he['t'], he['exp']

    rot_mat = get_rotation_matrix(yaw, pitch, roll)

    # keypoint rotation
    kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)

    # keypoint translation: out-of-place unsqueeze so the caller's he['t']
    # is not reshaped (an in-place unsqueeze_ would corrupt repeated calls).
    t = t.unsqueeze(1).repeat(1, kp.shape[1], 1)
    kp_t = kp_rotated + t

    # add expression deviation, reshaped to one 3-vector per keypoint
    exp = exp.view(exp.shape[0], -1, 3)
    kp_transformed = kp_t + exp

    if estimate_jacobian:
        jacobian = kp_canonical['jacobian']
        jacobian_transformed = torch.einsum('bmp,bkps->bkms', rot_mat, jacobian)
    else:
        jacobian_transformed = None

    return {'value': kp_transformed, 'jacobian': jacobian_transformed}
def make_animation(source_image, driving_video, generator, kp_detector, he_estimator, relative=True, adapt_movement_scale=True, estimate_jacobian=True, cpu=False, free_view=False, yaw=0, pitch=0, roll=0):
    """Animate ``source_image`` with the motion of ``driving_video``.

    Args:
        source_image: Channels-last (H, W, C) image array.
        driving_video: Sequence of channels-last (H, W, C) frames.
        generator, kp_detector, he_estimator: Models as returned by
            ``load_checkpoints``.
        relative: Forwarded to ``normalize_kp`` as ``use_relative_movement``.
        adapt_movement_scale: Forwarded to ``normalize_kp``.
        estimate_jacobian: Whether keypoint jacobians are transformed. Also
            forwarded to ``normalize_kp`` as ``use_relative_jacobian``.
        cpu: If True, inputs stay on the CPU; otherwise they are moved to CUDA.
        free_view, yaw, pitch, roll: Head-pose overrides applied to each
            driving frame's keypoints (see ``keypoint_transformation``).

    Returns:
        List of (H, W, C) numpy arrays, one per driving frame. The list is
        empty if ``driving_video`` is empty.
    """
    def to_batch(image):
        # (H, W, C) array -> (1, C, H, W) float32 tensor on the target device.
        # Converting frame by frame avoids materialising the whole video as a
        # single float32 tensor, which keeps peak memory down on long videos.
        tensor = torch.from_numpy(np.array(image, dtype=np.float32)).permute(2, 0, 1).unsqueeze(0)
        return tensor if cpu else tensor.cuda()

    if len(driving_video) == 0:
        return []

    predictions = []
    with torch.no_grad():
        source = to_batch(source_image)
        kp_canonical = kp_detector(source)
        he_source = he_estimator(source)
        he_driving_initial = he_estimator(to_batch(driving_video[0]))

        kp_source = keypoint_transformation(kp_canonical, he_source, estimate_jacobian)
        # The reference (first-frame) keypoints always use the estimated head
        # pose; free-view overrides are applied only to the per-frame keypoints.
        kp_driving_initial = keypoint_transformation(kp_canonical, he_driving_initial, estimate_jacobian)

        for frame in tqdm(driving_video):
            he_driving = he_estimator(to_batch(frame))
            kp_driving = keypoint_transformation(kp_canonical, he_driving, estimate_jacobian,
                                                 free_view=free_view, yaw=yaw, pitch=pitch, roll=roll)
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=estimate_jacobian, adapt_movement_scale=adapt_movement_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
            # (1, C, H, W) -> (H, W, C)
            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return predictions
def find_best_frame(source, driving, cpu=False):
    """Return the index of the driving frame whose face landmarks best match ``source``.

    Landmarks come from the ``face_alignment`` package. It is imported here,
    so it is only required when this function is used. Each landmark set is
    centred, its x/y coordinates are divided by sqrt(convex-hull area), and
    frames are compared by the sum of squared differences.

    Args:
        source: Source image. It is multiplied by 255 before landmark
            detection (presumably values in [0, 1]; TODO confirm).
        driving: Iterable of frames in the same format as ``source``.
        cpu: Run face_alignment on the CPU instead of CUDA.

    Returns:
        Index of the best-matching frame. Returns 0 if no driving frame
        contains a detectable face.

    Raises:
        ValueError: If no face is detected in ``source``.
    """
    import face_alignment

    def normalize_kp(kp):
        # Centre, then scale x/y by sqrt of the convex-hull area (2-D "volume").
        kp = kp - kp.mean(axis=0, keepdims=True)
        area = np.sqrt(ConvexHull(kp[:, :2]).volume)
        kp[:, :2] = kp[:, :2] / area
        return kp

    # Newer face_alignment releases name the enum member TWO_D; older ones
    # only provide _2D.
    landmarks_type = getattr(face_alignment.LandmarksType, 'TWO_D', None)
    if landmarks_type is None:
        landmarks_type = face_alignment.LandmarksType._2D
    fa = face_alignment.FaceAlignment(landmarks_type, flip_input=True,
                                      device='cpu' if cpu else 'cuda')

    def first_face(image):
        # get_landmarks may return None (or an empty list) when no face is found.
        landmarks = fa.get_landmarks(255 * image)
        if landmarks is None or len(landmarks) == 0:
            return None
        return landmarks[0]

    kp_source = first_face(source)
    if kp_source is None:
        raise ValueError("No face detected in the source image")
    kp_source = normalize_kp(kp_source)

    norm = float('inf')
    frame_num = 0
    for i, image in tqdm(enumerate(driving)):
        kp_driving = first_face(image)
        if kp_driving is None:
            continue  # frames without a detectable face cannot be compared
        kp_driving = normalize_kp(kp_driving)
        new_norm = ((kp_source - kp_driving) ** 2).sum()
        if new_norm < norm:
            norm = new_norm
            frame_num = i
    return frame_num
if __name__ == "__main__":
    # Command-line entry point: read the inputs, run the animation, and write
    # the result video at the driving video's fps.
    parser = ArgumentParser()
    parser.add_argument("--config", default='config/vox-256.yaml', help="path to config")
    parser.add_argument("--checkpoint", default='', help="path to checkpoint to restore")
    parser.add_argument("--source_image", default='', help="path to source image")
    parser.add_argument("--driving_video", default='', help="path to driving video")
    parser.add_argument("--result_video", default='', help="path to output")
    parser.add_argument("--gen", default="spade", choices=["original", "spade"])
    parser.add_argument("--relative", dest="relative", action="store_true", help="use relative or absolute keypoint coordinates")
    parser.add_argument("--adapt_scale", dest="adapt_scale", action="store_true", help="adapt movement scale based on convex hull of keypoints")
    parser.add_argument("--find_best_frame", dest="find_best_frame", action="store_true",
                        help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)")
    parser.add_argument("--best_frame", dest="best_frame", type=int, default=None,
                        help="Set frame to start from.")
    parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
    parser.add_argument("--free_view", dest="free_view", action="store_true", help="control head pose")
    parser.add_argument("--yaw", dest="yaw", type=int, default=None, help="yaw")
    parser.add_argument("--pitch", dest="pitch", type=int, default=None, help="pitch")
    parser.add_argument("--roll", dest="roll", type=int, default=None, help="roll")
    parser.set_defaults(relative=False)
    parser.set_defaults(adapt_scale=False)
    parser.set_defaults(free_view=False)
    opt = parser.parse_args()

    source_image = imageio.imread(opt.source_image)

    reader = imageio.get_reader(opt.driving_video)
    try:
        fps = reader.get_meta_data()['fps']
        driving_video = []
        try:
            for im in reader:
                driving_video.append(im)
        except RuntimeError:
            # Best effort: keep the frames decoded before the reader raised.
            pass
    finally:
        reader.close()

    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]
    generator, kp_detector, he_estimator = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, gen=opt.gen, cpu=opt.cpu)

    with open(opt.config) as f:
        # yaml.load() without an explicit Loader raises TypeError on PyYAML >= 6.
        config = yaml.safe_load(f)
    estimate_jacobian = config['model_params']['common_params']['estimate_jacobian']
    print(f'estimate jacobian: {estimate_jacobian}')

    # Options shared by every make_animation call below.
    animation_kwargs = dict(relative=opt.relative, adapt_movement_scale=opt.adapt_scale,
                            estimate_jacobian=estimate_jacobian, cpu=opt.cpu, free_view=opt.free_view,
                            yaw=opt.yaw, pitch=opt.pitch, roll=opt.roll)

    if opt.find_best_frame or opt.best_frame is not None:
        i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)
        if not 0 <= i < len(driving_video):
            parser.error(f"best frame {i} is out of range for a driving video with {len(driving_video)} frames")
        print("Best frame: " + str(i))
        # Animate forward from frame i and backward from frame i down to 0, then
        # stitch them together (frame i appears once).
        driving_forward = driving_video[i:]
        driving_backward = driving_video[:(i + 1)][::-1]
        predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector, he_estimator, **animation_kwargs)
        predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector, he_estimator, **animation_kwargs)
        predictions = predictions_backward[::-1] + predictions_forward[1:]
    else:
        predictions = make_animation(source_image, driving_video, generator, kp_detector, he_estimator, **animation_kwargs)

    imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fps=fps)