YinuoGuo27 commited on
Commit
3884a65
·
verified ·
1 Parent(s): 33a3a16

Update difpoint/inference.py

Browse files
Files changed (1) hide show
  1. difpoint/inference.py +15 -11
difpoint/inference.py CHANGED
@@ -403,7 +403,7 @@ class Inferencer(object):
403
  x_d_0_info = x_d_i_info
404
 
405
 
406
- if self.inf_cfg.infer_params.flag_relative_motion:
407
  R_new = (R_d_i.cpu().numpy() @ R_d_0.permute(0, 2, 1).cpu().numpy()) @ R_s
408
  delta_new = x_s_info['exp'].reshape(1, 21, -1) + (x_d_i_info['exp'] - x_d_0_info['exp'])
409
  scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
@@ -424,19 +424,22 @@ class Inferencer(object):
424
  x_d_i_new = x_d_i_new.cpu().numpy()
425
 
426
  # Algorithm 1:
427
- if not self.inf_cfg.infer_params.flag_stitching and not self.inf_cfg.infer_params.flag_eye_retargeting and not self.inf_cfg.infer_params.flag_lip_retargeting:
428
- pass
429
- elif self.inf_cfg.infer_params.flag_stitching and not self.inf_cfg.infer_params.flag_eye_retargeting and not self.inf_cfg.infer_params.flag_lip_retargeting:
 
 
 
 
430
  # with stitching and without retargeting
431
- x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new)
 
 
 
432
  else:
433
  eyes_delta, lip_delta = None, None
434
- if self.inf_cfg.infer_params.flag_eye_retargeting:
435
- pass
436
- if self.inf_cfg.infer_params.flag_lip_retargeting:
437
- pass
438
 
439
- if self.inf_cfg.infer_params.flag_relative_motion: # use x_s
440
  x_d_i_new = x_s + \
441
  (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
442
  (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
@@ -445,9 +448,10 @@ class Inferencer(object):
445
  (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
446
  (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
447
 
448
- if self.inf_cfg.infer_params.flag_stitching:
449
  x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new)
450
 
 
451
  out = self.live_portrait_pipeline.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
452
  I_p_i = self.live_portrait_pipeline.live_portrait_wrapper.parse_output(out['out'])[0]
453
  I_p_lst.append(I_p_i)
 
403
  x_d_0_info = x_d_i_info
404
 
405
 
406
+ if self.inf_cfg.flag_relative_motion:
407
  R_new = (R_d_i.cpu().numpy() @ R_d_0.permute(0, 2, 1).cpu().numpy()) @ R_s
408
  delta_new = x_s_info['exp'].reshape(1, 21, -1) + (x_d_i_info['exp'] - x_d_0_info['exp'])
409
  scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
 
424
  x_d_i_new = x_d_i_new.cpu().numpy()
425
 
426
  # Algorithm 1:
427
+ if not self.inf_cfg.flag_stitching and not self.inf_cfg.flag_eye_retargeting and not self.inf_cfg.flag_lip_retargeting:
428
+ # without stitching or retargeting
429
+ if flag_lip_zero:
430
+ x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
431
+ else:
432
+ pass
433
+ elif self.inf_cfg.flag_stitching and not self.inf_cfg.flag_eye_retargeting and not self.inf_cfg.flag_lip_retargeting:
434
  # with stitching and without retargeting
435
+ if flag_lip_zero:
436
+ x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
437
+ else:
438
+ x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new)
439
  else:
440
  eyes_delta, lip_delta = None, None
 
 
 
 
441
 
442
+ if self.inf_cfg.flag_relative_motion: # use x_s
443
  x_d_i_new = x_s + \
444
  (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
445
  (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
 
448
  (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
449
  (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
450
 
451
+ if self.inf_cfg.flag_stitching:
452
  x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new)
453
 
454
+
455
  out = self.live_portrait_pipeline.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
456
  I_p_i = self.live_portrait_pipeline.live_portrait_wrapper.parse_output(out['out'])[0]
457
  I_p_lst.append(I_p_i)