# ref: https://github.com/ShunyuYao/DFA-NeRF
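# Fits a 3DMM to per-frame 2D facial landmarks: a shared identity code and
# per-frame expression and head pose are optimized against the landmark
# reprojection error, the camera focal length is picked by a coarse grid
# search, and the tracked parameters are saved to track_params.pt.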
import torch
import numpy as np
from facemodel import Face_3DMM
from data_loader import load_dir
from util import *
import os
import sys
import cv2
import imageio
import argparse

dir_path = os.path.dirname(os.path.realpath(__file__))


def set_requires_grad(tensor_list):
    """Enable gradient tracking for every tensor in the list."""
    for tensor in tensor_list:
        tensor.requires_grad = True


parser = argparse.ArgumentParser()
parser.add_argument(
    "--path", type=str, default="obama/ori_imgs",
    help="directory containing the per-frame landmark files of the target person")
parser.add_argument('--img_h', type=int, default=512, help='image height')
parser.add_argument('--img_w', type=int, default=512, help='image width')
parser.add_argument('--frame_num', type=int,
                    default=11000, help='number of frames to load')
args = parser.parse_args()
start_id = 0
end_id = args.frame_num

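# Load the per-frame 2D landmarks for frames [start_id, end_id) and build the
# 3DMM face model from the bases stored in the local 3DMM/ directory.
# lands_info holds the 3DMM vertex indices corresponding to the landmarks.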
lms = load_dir(args.path, start_id, end_id)
num_frames = lms.shape[0]
h, w = args.img_h, args.img_w
cxy = torch.tensor((w/2.0, h/2.0), dtype=torch.float).cuda()
id_dim, exp_dim, tex_dim, point_num = 100, 79, 100, 34650
model_3dmm = Face_3DMM(os.path.join(dir_path, '3DMM'),
                       id_dim, exp_dim, tex_dim, point_num)
lands_info = np.loadtxt(os.path.join(
    dir_path, '3DMM', 'lands_info.txt'), dtype=np.int32)
lands_info = torch.as_tensor(lands_info).cuda()
# mesh = openmesh.read_trimesh(os.path.join(dir_path, '3DMM', 'template.obj'))
focal = 1150

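# Allocate the optimizable parameters: one identity and texture code shared
# across the sequence, plus per-frame expression, Euler-angle rotation,
# translation, and 27-dim lighting coefficients. The translation is initialized
# at z = -600, presumably to start the face at a reasonable distance from the camera.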
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
tex_para = lms.new_zeros((1, tex_dim), requires_grad=True)
euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
trans = lms.new_zeros((num_frames, 3), requires_grad=True)
light_para = lms.new_zeros((num_frames, 27), requires_grad=True)
trans.data[:, 2] -= 600
focal_length = lms.new_zeros(1, requires_grad=True)
focal_length.data += focal

set_requires_grad([id_para, exp_para, tex_para,
                   euler_angle, trans, light_para])

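# Stage 1: coarse grid search over the focal length. For each candidate focal
# in [500, 1450] (step 50), fit identity, expression, and per-frame pose on
# every 10th frame for 2000 iterations, then keep the focal length with the
# lowest landmark reprojection loss.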
sel_ids = np.arange(0, num_frames, 10)
sel_num = sel_ids.shape[0]
arg_focal = 0.0
arg_landis = 1e5
for focal in range(500, 1500, 50):
    id_para = lms.new_zeros((1, id_dim), requires_grad=True)
    exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True)
    euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True)
    trans = lms.new_zeros((sel_num, 3), requires_grad=True)
    trans.data[:, 2] -= 600
    focal_length = lms.new_zeros(1, requires_grad=False)
    focal_length.data += focal
    set_requires_grad([id_para, exp_para, euler_angle, trans])

    optimizer_id = torch.optim.Adam([id_para], lr=.3)
    optimizer_exp = torch.optim.Adam([exp_para], lr=.3)
    optimizer_frame = torch.optim.Adam(
        [euler_angle, trans], lr=.3)
    iter_num = 2000

    for i in range(iter_num):
        # Predict the 3DMM landmark vertices for every selected frame and
        # project them into the image with the current pose and focal length.
        id_para_batch = id_para.expand(sel_num, -1)
        geometry = model_3dmm.forward_geo_sub(
            id_para_batch, exp_para, lands_info[-51:].long())
        proj_geo = forward_transform(
            geometry, euler_angle, trans, focal_length, cxy)
        # Landmark reprojection loss plus L2 regularizers on identity and expression.
        loss_lan = cal_lan_loss(
            proj_geo[:, :, :2], lms[sel_ids, -51:, :].detach())
        loss_regid = torch.mean(id_para * id_para) * 8
        loss_regexp = torch.mean(exp_para * exp_para) * 0.5
        loss = loss_lan + loss_regid + loss_regexp
        optimizer_id.zero_grad()
        optimizer_exp.zero_grad()
        optimizer_frame.zero_grad()
        loss.backward()
        # Warm up the pose first; update identity and expression only after 1000 iterations.
        if i > 1000:
            optimizer_id.step()
            optimizer_exp.step()
        optimizer_frame.step()
    print(focal, loss_lan.item(), torch.mean(trans[:, 2]).item())
    if loss_lan.item() < arg_landis:
        arg_landis = loss_lan.item()
        arg_focal = focal

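# Stage 2: with the focal length fixed to the best value from the grid search,
# re-fit the shared identity code and the per-frame expression and pose on all
# frames.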
sel_ids = np.arange(0, num_frames)
sel_num = sel_ids.shape[0]
id_para = lms.new_zeros((1, id_dim), requires_grad=True)
exp_para = lms.new_zeros((sel_num, exp_dim), requires_grad=True)
euler_angle = lms.new_zeros((sel_num, 3), requires_grad=True)
trans = lms.new_zeros((sel_num, 3), requires_grad=True)
trans.data[:, 2] -= 600
focal_length = lms.new_zeros(1, requires_grad=False)
focal_length.data += arg_focal
set_requires_grad([id_para, exp_para, euler_angle, trans])

optimizer_id = torch.optim.Adam([id_para], lr=.3)
optimizer_exp = torch.optim.Adam([exp_para], lr=.3)
optimizer_frame = torch.optim.Adam(
    [euler_angle, trans], lr=.3)
iter_num = 2000

for i in range(iter_num):
    id_para_batch = id_para.expand(sel_num, -1)
    geometry = model_3dmm.forward_geo_sub(
        id_para_batch, exp_para, lands_info[-51:].long())
    proj_geo = forward_transform(
        geometry, euler_angle, trans, focal_length, cxy)
    loss_lan = cal_lan_loss(
        proj_geo[:, :, :2], lms[sel_ids, -51:, :].detach())
    loss_regid = torch.mean(id_para * id_para) * 8
    loss_regexp = torch.mean(exp_para * exp_para) * 0.5
    loss = loss_lan + loss_regid + loss_regexp
    optimizer_id.zero_grad()
    optimizer_exp.zero_grad()
    optimizer_frame.zero_grad()
    loss.backward()
    # As in stage 1, update identity and expression only after 1000 warm-up iterations.
    if i > 1000:
        optimizer_id.step()
        optimizer_exp.step()
    optimizer_frame.step()
print(arg_focal, loss_lan.item(), torch.mean(trans[:, 2]).item())


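# Save the tracked parameters (identity, per-frame expression and pose, focal
# length) to the parent of the landmark directory, e.g. obama/track_params.pt.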
torch.save({'id': id_para.detach().cpu(), 'exp': exp_para.detach().cpu(),
            'euler': euler_angle.detach().cpu(), 'trans': trans.detach().cpu(),
            'focal': focal_length.detach().cpu()}, os.path.join(os.path.dirname(args.path), 'track_params.pt'))
print('face tracking params saved')