File size: 5,559 Bytes
8d34f50 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
# ref: https://github.com/ShunyuYao/DFA-NeRF
from numpy.core.numeric import require
from numpy.lib.function_base import quantile
import torch
import numpy as np
from facemodel import Face_3DMM
from data_loader import load_dir
from util import *
import os
import sys
import cv2
import imageio
import argparse
# Directory that contains this script (symlinks resolved); the 3DMM assets
# are looked up relative to it.
dir_path = os.path.split(os.path.realpath(__file__))[0]
def set_requires_grad(tensor_list):
    """Enable gradient tracking on every element of ``tensor_list``.

    Each element has its ``requires_grad`` attribute set to ``True`` in
    place. Nothing is returned.
    """
    for param in tensor_list:
        param.requires_grad = True
# Command-line interface: one row per option (flag, type, default, help).
parser = argparse.ArgumentParser()
for _flag, _type, _default, _help in (
    ("--path", str, "obama/ori_imgs", "idname of target person"),
    ('--img_h', int, 512, 'image height'),
    ('--img_w', int, 512, 'image width'),
    ('--frame_num', int, 11000, 'image number'),
):
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
# Read the command line and load the landmarks for frames [0, frame_num).
args = parser.parse_args()
start_id, end_id = 0, args.frame_num
lms = load_dir(args.path, start_id, end_id)
# The first dimension of the loaded landmarks is taken as the frame count.
num_frames = lms.shape[0]

# Image centre (w/2, h/2) as a float tensor on the GPU.
h = args.img_h
w = args.img_w
cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).cuda()

# Sizes handed to the 3DMM face model.
id_dim = 100
exp_dim = 79
tex_dim = 100
point_num = 34650
model_3dmm = Face_3DMM(
    os.path.join(dir_path, '3DMM'), id_dim, exp_dim, tex_dim, point_num)

# Integer table read from 3DMM/lands_info.txt, moved to the GPU.
lands_info = torch.as_tensor(
    np.loadtxt(os.path.join(dir_path, '3DMM', 'lands_info.txt'),
               dtype=np.int32)
).cuda()
# The optimisation variables (id/exp/euler/trans/focal_length) are created
# right where they are optimised -- inside the focal-length search and again
# for the final full-sequence fit -- so nothing is pre-allocated here.

# Frames used for the coarse focal-length search: every 10th frame.
sel_ids = np.arange(0, num_frames, 10)
sel_num = sel_ids.shape[0]

# Best focal length found so far and the landmark loss it achieved.
# The loss starts at +inf so the first finite candidate is always accepted;
# a finite cap could leave `arg_focal` at 0.0 when every candidate's loss
# exceeds it, and the final fit would then run with a focal length of zero.
arg_focal = 0.0
arg_landis = float("inf")
# Coarse focal-length search.
#
# For every candidate focal length in 500, 550, ..., 1450 a fresh set of
# identity / expression / rotation / translation parameters is fitted to the
# last 51 landmarks of the subsampled frames (`sel_ids`). The candidate whose
# final landmark loss is lowest is kept in `arg_focal` / `arg_landis`.
#
# NOTE(review): the identity and expression optimizers are taken to be gated
# by `step > 1000`, while the rotation/translation optimizer steps on every
# iteration -- confirm this nesting against the upstream script.

# Loop-invariant inputs: landmark vertex indices and the target landmarks.
lm_vertex_ids = lands_info[-51:].long()
target_lms = lms[sel_ids, -51:, :].detach()

for focal in range(500, 1500, 50):
    # Fresh leaf tensors per candidate; requires_grad=True makes them
    # trainable without any further call.
    id_para = lms.new_zeros((1, id_dim), requires_grad=True)
    exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True)
    euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True)
    trans = lms.new_zeros((sel_num, 3), requires_grad=True)
    # Third translation component starts at -600.
    trans.data[:, 2] -= 600
    # The focal length is held fixed while fitting this candidate.
    focal_length = lms.new_zeros(1, requires_grad=False)
    focal_length.data += focal

    optimizer_id = torch.optim.Adam([id_para], lr=.3)
    optimizer_exp = torch.optim.Adam([exp_para], lr=.3)
    optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.3)

    iter_num = 2000
    for step in range(iter_num):
        id_para_batch = id_para.expand(sel_num, -1)
        geometry = model_3dmm.forward_geo_sub(
            id_para_batch, exp_para, lm_vertex_ids)
        proj_geo = forward_transform(
            geometry, euler_angle, trans, focal_length, cxy)
        loss_lan = cal_lan_loss(proj_geo[:, :, :2], target_lms)
        # L2 regularisers on the identity and expression coefficients.
        loss_regid = torch.mean(id_para * id_para) * 8
        loss_regexp = torch.mean(exp_para * exp_para) * 0.5
        loss = loss_lan + loss_regid + loss_regexp

        optimizer_id.zero_grad()
        optimizer_exp.zero_grad()
        optimizer_frame.zero_grad()
        loss.backward()
        if step > 1000:
            optimizer_id.step()
            optimizer_exp.step()
        optimizer_frame.step()

    print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item())

    # Keep the candidate with the lowest final landmark loss.
    if loss_lan.item() < arg_landis:
        arg_landis = loss_lan.item()
        arg_focal = focal
# Final fit over every frame, using the focal length chosen by the search
# (`arg_focal`). The resulting id_para / exp_para / euler_angle / trans /
# focal_length are what gets written to disk afterwards.
#
# NOTE(review): the identity and expression optimizers are taken to be gated
# by `step > 1000`, while the rotation/translation optimizer steps on every
# iteration -- confirm this nesting against the upstream script.
sel_ids = np.arange(0, num_frames)
sel_num = sel_ids.shape[0]

# Fresh leaf tensors; requires_grad=True makes them trainable without any
# further call.
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True)
euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True)
trans = lms.new_zeros((sel_num, 3), requires_grad=True)
# Third translation component starts at -600.
trans.data[:, 2] -= 600
# The focal length is fixed to the best candidate and not optimised.
focal_length = lms.new_zeros(1, requires_grad=False)
focal_length.data += arg_focal

optimizer_id = torch.optim.Adam([id_para], lr=.3)
optimizer_exp = torch.optim.Adam([exp_para], lr=.3)
optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.3)

# Loop-invariant inputs: landmark vertex indices and the target landmarks
# (last 51 landmarks of every selected frame).
lm_vertex_ids = lands_info[-51:].long()
target_lms = lms[sel_ids, -51:, :].detach()

iter_num = 2000
for step in range(iter_num):
    id_para_batch = id_para.expand(sel_num, -1)
    geometry = model_3dmm.forward_geo_sub(
        id_para_batch, exp_para, lm_vertex_ids)
    proj_geo = forward_transform(
        geometry, euler_angle, trans, focal_length, cxy)
    loss_lan = cal_lan_loss(proj_geo[:, :, :2], target_lms)
    # L2 regularisers on the identity and expression coefficients.
    loss_regid = torch.mean(id_para * id_para) * 8
    loss_regexp = torch.mean(exp_para * exp_para) * 0.5
    loss = loss_lan + loss_regid + loss_regexp

    optimizer_id.zero_grad()
    optimizer_exp.zero_grad()
    optimizer_frame.zero_grad()
    loss.backward()
    if step > 1000:
        optimizer_id.step()
        optimizer_exp.step()
    optimizer_frame.step()

print(arg_focal, loss_lan.item(), torch.mean(trans[:, 2]).item())
# Write the fitted parameters (detached and moved to the CPU) to
# `track_params.pt` in the parent directory of `--path`.
track_params = {
    'id': id_para.detach().cpu(),
    'exp': exp_para.detach().cpu(),
    'euler': euler_angle.detach().cpu(),
    'trans': trans.detach().cpu(),
    'focal': focal_length.detach().cpu(),
}
track_params_path = os.path.join(
    os.path.dirname(args.path), 'track_params.pt')
torch.save(track_params, track_params_path)
print('face tracking params saved')
|