Upload faster_live_portrait_pipeline.py
difpoint/src/pipelines/faster_live_portrait_pipeline.py
CHANGED
@@ -1,455 +1,455 @@
# -*- coding: utf-8 -*-
# @Author  : wenshao
# @Email   : [email protected]
# @Project : FasterLivePortrait
# @FileName: faster_live_portrait_pipeline.py

import copy
import pdb
import time
import traceback
from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
import torch

from .. import models
from ..utils.crop import crop_image, parse_bbox_from_landmark, crop_image_by_bbox, paste_back, paste_back_pytorch
from ..utils.utils import resize_to_limit, prepare_paste_back, get_rotation_matrix, calc_lip_close_ratio, \
    calc_eye_close_ratio, transform_keypoint, concat_feat
from difpoint.src.utils import utils


class FasterLivePortraitPipeline:
    def __init__(self, cfg, **kwargs):
        self.cfg = cfg
        self.init(**kwargs)

    def init(self, **kwargs):
        self.init_vars(**kwargs)
        self.init_models(**kwargs)

    def clean_models(self, **kwargs):
        """
        clean model
        :param kwargs:
        :return:
        """
        for key in list(self.model_dict.keys()):
            del self.model_dict[key]
        self.model_dict = {}

    def init_models(self, **kwargs):
        if not kwargs.get("is_animal", False):
            print("load Human Model >>>")
            self.is_animal = False
            self.model_dict = {}
            for model_name in self.cfg.models:
                print(f"loading model: {model_name}")
                print(self.cfg.models[model_name])
                self.model_dict[model_name] = getattr(models, self.cfg.models[model_name]["name"])(
                    **self.cfg.models[model_name])
        else:
            print("load Animal Model >>>")
            self.is_animal = True
            self.model_dict = {}
            from src.utils.animal_landmark_runner import XPoseRunner
            from src.utils.utils import make_abs_path

            xpose_ckpt_path: str = make_abs_path("../difpoint/checkpoints/liveportrait_animal_onnx/xpose.pth")
            xpose_config_file_path: str = make_abs_path("models/XPose/config_model/UniPose_SwinT.py")
            xpose_embedding_cache_path: str = make_abs_path('../difpoint/checkpoints/liveportrait_animal_onnx/clip_embedding')
            self.model_dict["xpose"] = XPoseRunner(model_config_path=xpose_config_file_path,
                                                   model_checkpoint_path=xpose_ckpt_path,
                                                   embeddings_cache_path=xpose_embedding_cache_path,
                                                   flag_use_half_precision=True)
            for model_name in self.cfg.animal_models:
                print(f"loading model: {model_name}")
                print(self.cfg.animal_models[model_name])
                self.model_dict[model_name] = getattr(models, self.cfg.animal_models[model_name]["name"])(
                    **self.cfg.animal_models[model_name])

    def init_vars(self, **kwargs):
        self.mask_crop = cv2.imread(self.cfg.infer_params.mask_crop_path, cv2.IMREAD_COLOR)
        self.frame_id = 0
        self.src_lmk_pre = None
        self.R_d_0 = None
        self.x_d_0_info = None
        self.R_d_smooth = utils.OneEuroFilter(4, 1)
        self.exp_smooth = utils.OneEuroFilter(4, 1)

        ## record the source information
        self.source_path = None
        self.src_infos = []
        self.src_imgs = []
        self.is_source_video = False

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
        c_s_eyes = calc_eye_close_ratio(source_lmk[None])
        c_d_eyes_i = np.array(c_d_eyes_i).reshape(1, 1)
        # [c_s,eyes, c_d,eyes,i]
        combined_eye_ratio_tensor = np.concatenate([c_s_eyes, c_d_eyes_i], axis=1)
        return combined_eye_ratio_tensor

    def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
        c_s_lip = calc_lip_close_ratio(source_lmk[None])
        c_d_lip_i = np.array(c_d_lip_i).reshape(1, 1)  # 1x1
        # [c_s,lip, c_d,lip,i]
        combined_lip_ratio_tensor = np.concatenate([c_s_lip, c_d_lip_i], axis=1)  # 1x2
        return combined_lip_ratio_tensor

    def prepare_source(self, source_path, **kwargs):
        print(f"process source:{source_path} >>>>>>>>")
        try:
            if utils.is_image(source_path):
                self.is_source_video = False
            elif utils.is_video(source_path):
                self.is_source_video = True
            else:  # source input is an unknown format
                raise Exception(f"Unknown source format: {source_path}")

            if self.is_source_video:
                src_imgs_bgr = []
                src_vcap = cv2.VideoCapture(source_path)
                while True:
                    ret, frame = src_vcap.read()
                    if not ret:
                        break
                    src_imgs_bgr.append(frame)
                src_vcap.release()
            else:
                img_bgr = cv2.imread(source_path, cv2.IMREAD_COLOR)
                src_imgs_bgr = [img_bgr]

            self.src_imgs = []
            self.src_infos = []
            self.source_path = source_path

            for ii, img_bgr in tqdm(enumerate(src_imgs_bgr), total=len(src_imgs_bgr)):
                img_bgr = resize_to_limit(img_bgr, self.cfg.infer_params.source_max_dim,
                                          self.cfg.infer_params.source_division)
                img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
                src_faces = []
                if self.is_animal:
                    with torch.no_grad():
                        img_rgb_pil = Image.fromarray(img_rgb)
                        lmk = self.model_dict["xpose"].run(
                            img_rgb_pil,
                            'face',
                            'animal_face',
                            0,
                            0
                        )
                    if lmk is None:
                        continue
                    self.src_imgs.append(img_rgb)
                    src_faces.append(lmk)
                else:
                    src_faces = self.model_dict["face_analysis"].predict(img_bgr)
                    if len(src_faces) == 0:
                        print("No face detected in this image.")
                        continue
                    self.src_imgs.append(img_rgb)
                # if realtime, only keep the largest face
                if kwargs.get("realtime", False):
                    src_faces = src_faces[:1]

                crop_infos = []
                for i in range(len(src_faces)):
                    # NOTE: temporarily only pick the first face, to support multiple faces in the future
                    lmk = src_faces[i]
                    # crop the face
                    ret_dct = crop_image(
                        img_rgb,  # ndarray
                        lmk,  # 106x2 or Nx2
                        dsize=self.cfg.crop_params.src_dsize,
                        scale=self.cfg.crop_params.src_scale,
                        vx_ratio=self.cfg.crop_params.src_vx_ratio,
                        vy_ratio=self.cfg.crop_params.src_vy_ratio,
                    )
                    if self.is_animal:
                        ret_dct["lmk_crop"] = lmk
                    else:
                        lmk = self.model_dict["landmark"].predict(img_rgb, lmk)
                        ret_dct["lmk_crop"] = lmk
                        ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / self.cfg.crop_params.src_dsize

                    # update a 256x256 version for network input
                    ret_dct["img_crop_256x256"] = cv2.resize(
                        ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA
                    )
                    crop_infos.append(ret_dct)

                src_infos = [[] for _ in range(len(crop_infos))]
                for i, crop_info in enumerate(crop_infos):
                    source_lmk = crop_info['lmk_crop']
                    img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
                    pitch, yaw, roll, t, exp, scale, kp = self.model_dict["motion_extractor"].predict(
                        img_crop_256x256)
                    x_s_info = {
                        "pitch": pitch,
                        "yaw": yaw,
                        "roll": roll,
                        "t": t,
                        "exp": exp,
                        "scale": scale,
                        "kp": kp
                    }
                    src_infos[i].append(copy.deepcopy(x_s_info))
                    x_c_s = kp
                    R_s = get_rotation_matrix(pitch, yaw, roll)
                    f_s = self.model_dict["app_feat_extractor"].predict(img_crop_256x256)
                    x_s = transform_keypoint(pitch, yaw, roll, t, exp, scale, kp)
                    src_infos[i].extend([source_lmk.copy(), R_s.copy(), f_s.copy(), x_s.copy(), x_c_s.copy()])
                    if not self.is_animal:
                        flag_lip_zero = self.cfg.infer_params.flag_normalize_lip  # not overwrite
                        if flag_lip_zero:
                            # let the lip-open scalar be 0 at first
                            c_d_lip_before_animation = [0.]
                            combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(
                                c_d_lip_before_animation, source_lmk)
                            if combined_lip_ratio_tensor_before_animation[0][
                                0] < self.cfg.infer_params.lip_normalize_threshold:
                                flag_lip_zero = False
                                src_infos[i].append(None)
                                src_infos[i].append(flag_lip_zero)
                            else:
                                lip_delta_before_animation = self.model_dict['stitching_lip_retarget'].predict(
                                    concat_feat(x_s, combined_lip_ratio_tensor_before_animation))
                                src_infos[i].append(lip_delta_before_animation.copy())
                                src_infos[i].append(flag_lip_zero)
                        else:
                            src_infos[i].append(None)
                            src_infos[i].append(flag_lip_zero)
                    else:
                        src_infos[i].append(None)
                        src_infos[i].append(False)

                    ######## prepare for pasteback ########
                    if self.cfg.infer_params.flag_pasteback and self.cfg.infer_params.flag_do_crop and self.cfg.infer_params.flag_stitching:
                        mask_ori_float = prepare_paste_back(self.mask_crop, crop_info['M_c2o'],
                                                            dsize=(img_rgb.shape[1], img_rgb.shape[0]))
                        mask_ori_float = torch.from_numpy(mask_ori_float).to(self.device)
                        src_infos[i].append(mask_ori_float)
                    else:
                        src_infos[i].append(None)
                    M = torch.from_numpy(crop_info['M_c2o']).to(self.device)
                    src_infos[i].append(M)
                self.src_infos.append(src_infos[:])
            print(f"finish process source:{source_path} >>>>>>>>")
            return len(self.src_infos) > 0
        except Exception as e:
            traceback.print_exc()
            return False

    def retarget_eye(self, kp_source, eye_close_ratio):
        """
        kp_source: BxNx3
        eye_close_ratio: Bx3
        Return: Bx(3*num_kp+2)
        """
        feat_eye = concat_feat(kp_source, eye_close_ratio)
        delta = self.model_dict['stitching_eye_retarget'].predict(feat_eye)
        return delta

    def retarget_lip(self, kp_source, lip_close_ratio):
        """
        kp_source: BxNx3
        lip_close_ratio: Bx2
        """
        feat_lip = concat_feat(kp_source, lip_close_ratio)
        delta = self.model_dict['stitching_lip_retarget'].predict(feat_lip)
        return delta

    def stitching(self, kp_source, kp_driving):
        """ conduct the stitching
        kp_source: Bxnum_kpx3
        kp_driving: Bxnum_kpx3
        """

        bs, num_kp = kp_source.shape[:2]

        kp_driving_new = kp_driving.copy()

        delta = self.model_dict['stitching'].predict(concat_feat(kp_source, kp_driving_new))

        delta_exp = delta[..., :3 * num_kp].reshape(bs, num_kp, 3)  # 1x20x3
        delta_tx_ty = delta[..., 3 * num_kp:3 * num_kp + 2].reshape(bs, 1, 2)  # 1x1x2

        kp_driving_new += delta_exp
        kp_driving_new[..., :2] += delta_tx_ty

        return kp_driving_new

    def run(self, image, img_src, src_info, **kwargs):
        img_bgr = image
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        I_p_pstbk = torch.from_numpy(img_src).to(self.device).float()
        realtime = kwargs.get("realtime", False)

        if self.cfg.infer_params.flag_crop_driving_video:
            if self.src_lmk_pre is None:
                src_face = self.model_dict["face_analysis"].predict(img_bgr)
                if len(src_face) == 0:
                    self.src_lmk_pre = None
                    return None, None, None
                lmk = src_face[0]
                lmk = self.model_dict["landmark"].predict(img_rgb, lmk)
                self.src_lmk_pre = lmk.copy()
            else:
                lmk = self.model_dict["landmark"].predict(img_rgb, self.src_lmk_pre)
                self.src_lmk_pre = lmk.copy()

            ret_bbox = parse_bbox_from_landmark(
                lmk,
                scale=self.cfg.crop_params.dri_scale,
                vx_ratio_crop_video=self.cfg.crop_params.dri_vx_ratio,
                vy_ratio=self.cfg.crop_params.dri_vy_ratio,
            )["bbox"]
            global_bbox = [
                ret_bbox[0, 0],
                ret_bbox[0, 1],
                ret_bbox[2, 0],
                ret_bbox[2, 1],
            ]
            ret_dct = crop_image_by_bbox(
                img_rgb,
                global_bbox,
                lmk=lmk,
                dsize=kwargs.get("dsize", 512),
                flag_rot=False,
                borderValue=(0, 0, 0),
            )
            lmk_crop = ret_dct["lmk_crop"]
            img_crop = ret_dct["img_crop"]
            img_crop = cv2.resize(img_crop, (256, 256))
        else:
            if self.src_lmk_pre is None:
                src_face = self.model_dict["face_analysis"].predict(img_bgr)
                if len(src_face) == 0:
                    self.src_lmk_pre = None
                    return None, None, None
                lmk = src_face[0]
                lmk = self.model_dict["landmark"].predict(img_rgb, lmk)
                self.src_lmk_pre = lmk.copy()
            else:
                lmk = self.model_dict["landmark"].predict(img_rgb, self.src_lmk_pre)
                self.src_lmk_pre = lmk.copy()
            lmk_crop = lmk.copy()
            img_crop = cv2.resize(img_rgb, (256, 256))

        input_eye_ratio = calc_eye_close_ratio(lmk_crop[None])
        input_lip_ratio = calc_lip_close_ratio(lmk_crop[None])
        pitch, yaw, roll, t, exp, scale, kp = self.model_dict["motion_extractor"].predict(img_crop)
        x_d_i_info = {
            "pitch": pitch,
            "yaw": yaw,
            "roll": roll,
            "t": t,
            "exp": exp,
            "scale": scale,
            "kp": kp
        }
        R_d_i = get_rotation_matrix(pitch, yaw, roll)

        if kwargs.get("first_frame", False) or self.R_d_0 is None:
            self.R_d_0 = R_d_i.copy()
            self.x_d_0_info = copy.deepcopy(x_d_i_info)
            # realtime smooth
            self.R_d_smooth = utils.OneEuroFilter(4, 1)
            self.exp_smooth = utils.OneEuroFilter(4, 1)
        R_d_0 = self.R_d_0.copy()
        x_d_0_info = copy.deepcopy(self.x_d_0_info)
        out_crop, out_org = None, None
        for j in range(len(src_info)):
            x_s_info, source_lmk, R_s, f_s, x_s, x_c_s, lip_delta_before_animation, flag_lip_zero, mask_ori_float, M = \
                src_info[j]
            if self.cfg.infer_params.flag_relative_motion:
                if self.is_source_video:
                    if self.cfg.infer_params.flag_video_editing_head_rotation:
                        R_new = (R_d_i @ np.transpose(R_d_0, (0, 2, 1))) @ R_s
                        R_new = self.R_d_smooth.process(R_new)
                    else:
                        R_new = R_s
                else:
                    R_new = (R_d_i @ np.transpose(R_d_0, (0, 2, 1))) @ R_s
                delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
                if self.is_source_video:
                    delta_new = self.exp_smooth.process(delta_new)
                scale_new = x_s_info['scale'] if self.is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
                t_new = x_s_info['t'] if self.is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
            else:
                if self.is_source_video:
                    if self.cfg.infer_params.flag_video_editing_head_rotation:
                        R_new = R_d_i
                        R_new = self.R_d_smooth.process(R_new)
                    else:
                        R_new = R_s
                else:
                    R_new = R_d_i
                delta_new = x_d_i_info['exp'].copy()
                if self.is_source_video:
                    delta_new = self.exp_smooth.process(delta_new)
                scale_new = x_s_info['scale'].copy()
                t_new = x_d_i_info['t'].copy()

            t_new[..., 2] = 0  # zero tz
            x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
            if not self.is_animal:
                # Algorithm 1:
                if not self.cfg.infer_params.flag_stitching and not self.cfg.infer_params.flag_eye_retargeting and not self.cfg.infer_params.flag_lip_retargeting:
                    # without stitching or retargeting
                    if flag_lip_zero:
                        x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                    else:
                        pass
                elif self.cfg.infer_params.flag_stitching and not self.cfg.infer_params.flag_eye_retargeting and not self.cfg.infer_params.flag_lip_retargeting:
                    # with stitching and without retargeting
                    if flag_lip_zero:
                        x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(
                            -1, x_s.shape[1], 3)
                    else:
                        x_d_i_new = self.stitching(x_s, x_d_i_new)
                else:
                    eyes_delta, lip_delta = None, None
                    if self.cfg.infer_params.flag_eye_retargeting:
                        c_d_eyes_i = input_eye_ratio
                        combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i,
                                                                                 source_lmk)
                        # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
                        eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
                    if self.cfg.infer_params.flag_lip_retargeting:
                        c_d_lip_i = input_lip_ratio
                        combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
                        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
                        lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)

                    if self.cfg.infer_params.flag_relative_motion:  # use x_s
                        x_d_i_new = x_s + \
                                    (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                                    (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
                    else:  # use x_d,i
                        x_d_i_new = x_d_i_new + \
                                    (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                                    (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)

                    if self.cfg.infer_params.flag_stitching:
                        x_d_i_new = self.stitching(x_s, x_d_i_new)
            else:
                if self.cfg.infer_params.flag_stitching:
                    x_d_i_new = self.stitching(x_s, x_d_i_new)

            x_d_i_new = x_s + (x_d_i_new - x_s) * self.cfg.infer_params.driving_multiplier
            out_crop = self.model_dict["warping_spade"].predict(f_s, x_s, x_d_i_new)
            if not realtime and self.cfg.infer_params.flag_pasteback and self.cfg.infer_params.flag_do_crop and self.cfg.infer_params.flag_stitching:
                # TODO: pasteback is slow, consider optimizing it with multi-threading or GPU
                # I_p_pstbk = paste_back(out_crop, crop_info['M_c2o'], I_p_pstbk, mask_ori_float)
                I_p_pstbk = paste_back_pytorch(out_crop, M, I_p_pstbk, mask_ori_float)

        return img_crop, out_crop.to(dtype=torch.uint8).cpu().numpy(), I_p_pstbk.to(dtype=torch.uint8).cpu().numpy()

    def __del__(self):
        self.clean_models()
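For reference, a minimal usage sketch of this pipeline (not part of the uploaded file): it assumes the repository is importable as the difpoint package, that the config is an OmegaConf YAML exposing the models, infer_params, and crop_params fields referenced above, and that the config path and asset filenames shown here are placeholders.

# Hypothetical usage sketch; paths below are assumptions, not values from this commit.
import cv2
from omegaconf import OmegaConf

from difpoint.src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline

cfg = OmegaConf.load("difpoint/configs/onnx_infer.yaml")  # assumed config location
pipe = FasterLivePortraitPipeline(cfg, is_animal=False)

# prepare_source accepts an image or video path and caches per-face source info
if pipe.prepare_source("assets/source.jpg"):
    vcap = cv2.VideoCapture("assets/driving.mp4")
    first = True
    while True:
        ret, frame = vcap.read()
        if not ret:
            break
        # run() takes a driving frame (BGR), the source image (RGB), and its cached info;
        # it returns (None, None, None) when no face is found in the driving frame
        img_crop, out_crop, out_full = pipe.run(frame, pipe.src_imgs[0], pipe.src_infos[0],
                                                first_frame=first)
        first = False
        if img_crop is None:
            continue
        # out_crop is the animated crop, out_full the result pasted back onto the source
    vcap.release()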