YinuoGuo27 committed on
Commit 6635078 · verified · 1 Parent(s): fe29479

Upload faster_live_portrait_pipeline.py

difpoint/src/pipelines/faster_live_portrait_pipeline.py CHANGED
# -*- coding: utf-8 -*-
# @Author : wenshao
# @Email : [email protected]
# @Project : FasterLivePortrait
# @FileName: faster_live_portrait_pipeline.py

import copy
import pdb
import time
import traceback
from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
import torch

from .. import models
from ..utils.crop import crop_image, parse_bbox_from_landmark, crop_image_by_bbox, paste_back, paste_back_pytorch
from ..utils.utils import resize_to_limit, prepare_paste_back, get_rotation_matrix, calc_lip_close_ratio, \
    calc_eye_close_ratio, transform_keypoint, concat_feat
from difpoint.src.utils import utils


class FasterLivePortraitPipeline:
    def __init__(self, cfg, **kwargs):
        self.cfg = cfg
        self.init(**kwargs)

    def init(self, **kwargs):
        self.init_vars(**kwargs)
        self.init_models(**kwargs)

    def clean_models(self, **kwargs):
        """
        Release all loaded models.
        :param kwargs:
        :return:
        """
        for key in list(self.model_dict.keys()):
            del self.model_dict[key]
        self.model_dict = {}

    def init_models(self, **kwargs):
        if not kwargs.get("is_animal", False):
            print("load Human Model >>>")
            self.is_animal = False
            self.model_dict = {}
            for model_name in self.cfg.models:
                print(f"loading model: {model_name}")
                print(self.cfg.models[model_name])
                self.model_dict[model_name] = getattr(models, self.cfg.models[model_name]["name"])(
                    **self.cfg.models[model_name])
        else:
            print("load Animal Model >>>")
            self.is_animal = True
            self.model_dict = {}
            from src.utils.animal_landmark_runner import XPoseRunner
            from src.utils.utils import make_abs_path

            xpose_ckpt_path: str = make_abs_path("../difpoint/checkpoints/liveportrait_animal_onnx/xpose.pth")
            xpose_config_file_path: str = make_abs_path("models/XPose/config_model/UniPose_SwinT.py")
            xpose_embedding_cache_path: str = make_abs_path('../difpoint/checkpoints/liveportrait_animal_onnx/clip_embedding')
            self.model_dict["xpose"] = XPoseRunner(model_config_path=xpose_config_file_path,
                                                   model_checkpoint_path=xpose_ckpt_path,
                                                   embeddings_cache_path=xpose_embedding_cache_path,
                                                   flag_use_half_precision=True)
            for model_name in self.cfg.animal_models:
                print(f"loading model: {model_name}")
                print(self.cfg.animal_models[model_name])
                self.model_dict[model_name] = getattr(models, self.cfg.animal_models[model_name]["name"])(
                    **self.cfg.animal_models[model_name])

    def init_vars(self, **kwargs):
        self.mask_crop = cv2.imread(self.cfg.infer_params.mask_crop_path, cv2.IMREAD_COLOR)
        self.frame_id = 0
        self.src_lmk_pre = None
        self.R_d_0 = None
        self.x_d_0_info = None
        self.R_d_smooth = utils.OneEuroFilter(4, 1)
        self.exp_smooth = utils.OneEuroFilter(4, 1)

        ## record the source's information
        self.source_path = None
        self.src_infos = []
        self.src_imgs = []
        self.is_source_video = False

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
        c_s_eyes = calc_eye_close_ratio(source_lmk[None])
        c_d_eyes_i = np.array(c_d_eyes_i).reshape(1, 1)
        # [c_s,eyes, c_d,eyes,i]
        combined_eye_ratio_tensor = np.concatenate([c_s_eyes, c_d_eyes_i], axis=1)
        return combined_eye_ratio_tensor

    def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
        c_s_lip = calc_lip_close_ratio(source_lmk[None])
        c_d_lip_i = np.array(c_d_lip_i).reshape(1, 1)  # 1x1
        # [c_s,lip, c_d,lip,i]
        combined_lip_ratio_tensor = np.concatenate([c_s_lip, c_d_lip_i], axis=1)  # 1x2
        return combined_lip_ratio_tensor

    def prepare_source(self, source_path, **kwargs):
        print(f"process source:{source_path} >>>>>>>>")
        try:
            if utils.is_image(source_path):
                self.is_source_video = False
            elif utils.is_video(source_path):
                self.is_source_video = True
            else:  # source input is an unknown format
                raise Exception(f"Unknown source format: {source_path}")

            if self.is_source_video:
                src_imgs_bgr = []
                src_vcap = cv2.VideoCapture(source_path)
                while True:
                    ret, frame = src_vcap.read()
                    if not ret:
                        break
                    src_imgs_bgr.append(frame)
                src_vcap.release()
            else:
                img_bgr = cv2.imread(source_path, cv2.IMREAD_COLOR)
                src_imgs_bgr = [img_bgr]

            self.src_imgs = []
            self.src_infos = []
            self.source_path = source_path

            for ii, img_bgr in tqdm(enumerate(src_imgs_bgr), total=len(src_imgs_bgr)):
                img_bgr = resize_to_limit(img_bgr, self.cfg.infer_params.source_max_dim,
                                          self.cfg.infer_params.source_division)
                img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
                src_faces = []
                if self.is_animal:
                    with torch.no_grad():
                        img_rgb_pil = Image.fromarray(img_rgb)
                        lmk = self.model_dict["xpose"].run(
                            img_rgb_pil,
                            'face',
                            'animal_face',
                            0,
                            0
                        )
                    if lmk is None:
                        continue
                    self.src_imgs.append(img_rgb)
                    src_faces.append(lmk)
                else:
                    src_faces = self.model_dict["face_analysis"].predict(img_bgr)
                    if len(src_faces) == 0:
                        print("No face detected in this image.")
                        continue
                    self.src_imgs.append(img_rgb)
                # in realtime mode, only keep the largest face
                if kwargs.get("realtime", False):
                    src_faces = src_faces[:1]

                crop_infos = []
                for i in range(len(src_faces)):
                    # NOTE: temporarily only pick the first face; support for multiple faces may be added in the future
                    lmk = src_faces[i]
                    # crop the face
                    ret_dct = crop_image(
                        img_rgb,  # ndarray
                        lmk,  # 106x2 or Nx2
                        dsize=self.cfg.crop_params.src_dsize,
                        scale=self.cfg.crop_params.src_scale,
                        vx_ratio=self.cfg.crop_params.src_vx_ratio,
                        vy_ratio=self.cfg.crop_params.src_vy_ratio,
                    )
                    if self.is_animal:
                        ret_dct["lmk_crop"] = lmk
                    else:
                        lmk = self.model_dict["landmark"].predict(img_rgb, lmk)
                        ret_dct["lmk_crop"] = lmk
                        ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / self.cfg.crop_params.src_dsize

                    # update a 256x256 version for network input
                    ret_dct["img_crop_256x256"] = cv2.resize(
                        ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA
                    )
                    crop_infos.append(ret_dct)

                src_infos = [[] for _ in range(len(crop_infos))]
                for i, crop_info in enumerate(crop_infos):
                    source_lmk = crop_info['lmk_crop']
                    img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
                    pitch, yaw, roll, t, exp, scale, kp = self.model_dict["motion_extractor"].predict(
                        img_crop_256x256)
                    x_s_info = {
                        "pitch": pitch,
                        "yaw": yaw,
                        "roll": roll,
                        "t": t,
                        "exp": exp,
                        "scale": scale,
                        "kp": kp
                    }
                    src_infos[i].append(copy.deepcopy(x_s_info))
                    x_c_s = kp
                    R_s = get_rotation_matrix(pitch, yaw, roll)
                    f_s = self.model_dict["app_feat_extractor"].predict(img_crop_256x256)
                    x_s = transform_keypoint(pitch, yaw, roll, t, exp, scale, kp)
                    src_infos[i].extend([source_lmk.copy(), R_s.copy(), f_s.copy(), x_s.copy(), x_c_s.copy()])
                    if not self.is_animal:
                        flag_lip_zero = self.cfg.infer_params.flag_normalize_lip  # do not overwrite the config value
                        if flag_lip_zero:
                            # let the lip-open scalar be 0 at first
                            c_d_lip_before_animation = [0.]
                            combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(
                                c_d_lip_before_animation, source_lmk)
                            if combined_lip_ratio_tensor_before_animation[0][
                                0] < self.cfg.infer_params.lip_normalize_threshold:
                                flag_lip_zero = False
                                src_infos[i].append(None)
                                src_infos[i].append(flag_lip_zero)
                            else:
                                lip_delta_before_animation = self.model_dict['stitching_lip_retarget'].predict(
                                    concat_feat(x_s, combined_lip_ratio_tensor_before_animation))
                                src_infos[i].append(lip_delta_before_animation.copy())
                                src_infos[i].append(flag_lip_zero)
                        else:
                            src_infos[i].append(None)
                            src_infos[i].append(flag_lip_zero)
                    else:
                        src_infos[i].append(None)
                        src_infos[i].append(False)

                    ######## prepare for pasteback ########
                    if self.cfg.infer_params.flag_pasteback and self.cfg.infer_params.flag_do_crop and self.cfg.infer_params.flag_stitching:
                        mask_ori_float = prepare_paste_back(self.mask_crop, crop_info['M_c2o'],
                                                            dsize=(img_rgb.shape[1], img_rgb.shape[0]))
                        mask_ori_float = torch.from_numpy(mask_ori_float).to(self.device)
                        src_infos[i].append(mask_ori_float)
                    else:
                        src_infos[i].append(None)
                    M = torch.from_numpy(crop_info['M_c2o']).to(self.device)
                    src_infos[i].append(M)
                self.src_infos.append(src_infos[:])
            print(f"finish process source:{source_path} >>>>>>>>")
            return len(self.src_infos) > 0
        except Exception as e:
            traceback.print_exc()
            return False

    def retarget_eye(self, kp_source, eye_close_ratio):
        """
        kp_source: BxNx3
        eye_close_ratio: Bx3
        Return: Bx(3*num_kp+2)
        """
        feat_eye = concat_feat(kp_source, eye_close_ratio)
        delta = self.model_dict['stitching_eye_retarget'].predict(feat_eye)
        return delta

    def retarget_lip(self, kp_source, lip_close_ratio):
        """
        kp_source: BxNx3
        lip_close_ratio: Bx2
        """
        feat_lip = concat_feat(kp_source, lip_close_ratio)
        delta = self.model_dict['stitching_lip_retarget'].predict(feat_lip)
        return delta

    def stitching(self, kp_source, kp_driving):
        """ conduct the stitching
        kp_source: Bxnum_kpx3
        kp_driving: Bxnum_kpx3
        """

        bs, num_kp = kp_source.shape[:2]

        kp_driving_new = kp_driving.copy()

        delta = self.model_dict['stitching'].predict(concat_feat(kp_source, kp_driving_new))

        delta_exp = delta[..., :3 * num_kp].reshape(bs, num_kp, 3)  # 1x20x3
        delta_tx_ty = delta[..., 3 * num_kp:3 * num_kp + 2].reshape(bs, 1, 2)  # 1x1x2

        kp_driving_new += delta_exp
        kp_driving_new[..., :2] += delta_tx_ty

        return kp_driving_new

    def run(self, image, img_src, src_info, **kwargs):
        img_bgr = image
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        I_p_pstbk = torch.from_numpy(img_src).to(self.device).float()
        realtime = kwargs.get("realtime", False)

        if self.cfg.infer_params.flag_crop_driving_video:
            if self.src_lmk_pre is None:
                src_face = self.model_dict["face_analysis"].predict(img_bgr)
                if len(src_face) == 0:
                    self.src_lmk_pre = None
                    return None, None, None
                lmk = src_face[0]
                lmk = self.model_dict["landmark"].predict(img_rgb, lmk)
                self.src_lmk_pre = lmk.copy()
            else:
                lmk = self.model_dict["landmark"].predict(img_rgb, self.src_lmk_pre)
                self.src_lmk_pre = lmk.copy()

            ret_bbox = parse_bbox_from_landmark(
                lmk,
                scale=self.cfg.crop_params.dri_scale,
                vx_ratio_crop_video=self.cfg.crop_params.dri_vx_ratio,
                vy_ratio=self.cfg.crop_params.dri_vy_ratio,
            )["bbox"]
            global_bbox = [
                ret_bbox[0, 0],
                ret_bbox[0, 1],
                ret_bbox[2, 0],
                ret_bbox[2, 1],
            ]
            ret_dct = crop_image_by_bbox(
                img_rgb,
                global_bbox,
                lmk=lmk,
                dsize=kwargs.get("dsize", 512),
                flag_rot=False,
                borderValue=(0, 0, 0),
            )
            lmk_crop = ret_dct["lmk_crop"]
            img_crop = ret_dct["img_crop"]
            img_crop = cv2.resize(img_crop, (256, 256))
        else:
            if self.src_lmk_pre is None:
                src_face = self.model_dict["face_analysis"].predict(img_bgr)
                if len(src_face) == 0:
                    self.src_lmk_pre = None
                    return None, None, None
                lmk = src_face[0]
                lmk = self.model_dict["landmark"].predict(img_rgb, lmk)
                self.src_lmk_pre = lmk.copy()
            else:
                lmk = self.model_dict["landmark"].predict(img_rgb, self.src_lmk_pre)
                self.src_lmk_pre = lmk.copy()
            lmk_crop = lmk.copy()
            img_crop = cv2.resize(img_rgb, (256, 256))

        input_eye_ratio = calc_eye_close_ratio(lmk_crop[None])
        input_lip_ratio = calc_lip_close_ratio(lmk_crop[None])
        pitch, yaw, roll, t, exp, scale, kp = self.model_dict["motion_extractor"].predict(img_crop)
        x_d_i_info = {
            "pitch": pitch,
            "yaw": yaw,
            "roll": roll,
            "t": t,
            "exp": exp,
            "scale": scale,
            "kp": kp
        }
        R_d_i = get_rotation_matrix(pitch, yaw, roll)

        if kwargs.get("first_frame", False) or self.R_d_0 is None:
            self.R_d_0 = R_d_i.copy()
            self.x_d_0_info = copy.deepcopy(x_d_i_info)
            # reset the realtime smoothing filters
            self.R_d_smooth = utils.OneEuroFilter(4, 1)
            self.exp_smooth = utils.OneEuroFilter(4, 1)
        R_d_0 = self.R_d_0.copy()
        x_d_0_info = copy.deepcopy(self.x_d_0_info)
        out_crop, out_org = None, None
        for j in range(len(src_info)):
            x_s_info, source_lmk, R_s, f_s, x_s, x_c_s, lip_delta_before_animation, flag_lip_zero, mask_ori_float, M = \
                src_info[j]
            if self.cfg.infer_params.flag_relative_motion:
                if self.is_source_video:
                    if self.cfg.infer_params.flag_video_editing_head_rotation:
                        R_new = (R_d_i @ np.transpose(R_d_0, (0, 2, 1))) @ R_s
                        R_new = self.R_d_smooth.process(R_new)
                    else:
                        R_new = R_s
                else:
                    R_new = (R_d_i @ np.transpose(R_d_0, (0, 2, 1))) @ R_s
                delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
                if self.is_source_video:
                    delta_new = self.exp_smooth.process(delta_new)
                scale_new = x_s_info['scale'] if self.is_source_video else x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
                t_new = x_s_info['t'] if self.is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
            else:
                if self.is_source_video:
                    if self.cfg.infer_params.flag_video_editing_head_rotation:
                        R_new = R_d_i
                        R_new = self.R_d_smooth.process(R_new)
                    else:
                        R_new = R_s
                else:
                    R_new = R_d_i
                delta_new = x_d_i_info['exp'].copy()
                if self.is_source_video:
                    delta_new = self.exp_smooth.process(delta_new)
                scale_new = x_s_info['scale'].copy()
                t_new = x_d_i_info['t'].copy()

            t_new[..., 2] = 0  # zero tz
            x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
            if not self.is_animal:
                # Algorithm 1:
                if not self.cfg.infer_params.flag_stitching and not self.cfg.infer_params.flag_eye_retargeting and not self.cfg.infer_params.flag_lip_retargeting:
                    # without stitching or retargeting
                    if flag_lip_zero:
                        x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                    else:
                        pass
                elif self.cfg.infer_params.flag_stitching and not self.cfg.infer_params.flag_eye_retargeting and not self.cfg.infer_params.flag_lip_retargeting:
                    # with stitching and without retargeting
                    if flag_lip_zero:
                        x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(
                            -1, x_s.shape[1], 3)
                    else:
                        x_d_i_new = self.stitching(x_s, x_d_i_new)
                else:
                    eyes_delta, lip_delta = None, None
                    if self.cfg.infer_params.flag_eye_retargeting:
                        c_d_eyes_i = input_eye_ratio
                        combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i,
                                                                                 source_lmk)
                        # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
                        eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
                    if self.cfg.infer_params.flag_lip_retargeting:
                        c_d_lip_i = input_lip_ratio
                        combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
                        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
                        lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)

                    if self.cfg.infer_params.flag_relative_motion:  # use x_s
                        x_d_i_new = x_s + \
                                    (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                                    (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
                    else:  # use x_d,i
                        x_d_i_new = x_d_i_new + \
                                    (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                                    (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)

                    if self.cfg.infer_params.flag_stitching:
                        x_d_i_new = self.stitching(x_s, x_d_i_new)
            else:
                if self.cfg.infer_params.flag_stitching:
                    x_d_i_new = self.stitching(x_s, x_d_i_new)

            x_d_i_new = x_s + (x_d_i_new - x_s) * self.cfg.infer_params.driving_multiplier
            out_crop = self.model_dict["warping_spade"].predict(f_s, x_s, x_d_i_new)
            if not realtime and self.cfg.infer_params.flag_pasteback and self.cfg.infer_params.flag_do_crop and self.cfg.infer_params.flag_stitching:
                # TODO: pasteback is slow; consider optimizing it with multi-threading or on the GPU
                # I_p_pstbk = paste_back(out_crop, crop_info['M_c2o'], I_p_pstbk, mask_ori_float)
                I_p_pstbk = paste_back_pytorch(out_crop, M, I_p_pstbk, mask_ori_float)

        return img_crop, out_crop.to(dtype=torch.uint8).cpu().numpy(), I_p_pstbk.to(dtype=torch.uint8).cpu().numpy()

    def __del__(self):
        self.clean_models()
 
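For reference, here is a minimal, hypothetical usage sketch of the pipeline class uploaded above. It assumes an OmegaConf-style config whose fields match those referenced in the class (cfg.models, cfg.infer_params, cfg.crop_params); the config path, asset paths, and the frame loop are illustrative assumptions, not the repository's actual entry point.

# Minimal usage sketch (assumptions: config/asset paths are placeholders; the real
# driver scripts in the repository may wire this up differently).
import cv2
from omegaconf import OmegaConf

from difpoint.src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline

cfg = OmegaConf.load("configs/onnx_infer.yaml")  # assumed config layout
pipe = FasterLivePortraitPipeline(cfg, is_animal=False)

# Cache motion features for the source image (or video) once.
if not pipe.prepare_source("assets/source.jpg"):
    raise RuntimeError("no face found in the source")

# Drive the cached source with every frame of a driving video.
vcap = cv2.VideoCapture("assets/driving.mp4")
first_frame = True
while True:
    ret, frame = vcap.read()  # frame is BGR, as run() expects
    if not ret:
        break
    img_crop, out_crop, out_org = pipe.run(frame, pipe.src_imgs[0], pipe.src_infos[0],
                                           first_frame=first_frame)
    first_frame = False
    if out_crop is None:  # no face detected in this driving frame
        continue
    # out_org is the animated result pasted back into the full source frame (RGB, uint8).
vcap.release()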