YinuoGuo27 committed
Commit 98adfb6 (verified) · 1 Parent(s): bcf7449

Update difpoint/inference.py

Switch inference from FasterLivePortraitPipeline (loaded from the OmegaConf trt_mp_infer.yaml config) to the PyTorch LivePortraitPipeline, configured through tyro-parsed ArgumentConfig / InferenceConfig / CropConfig, and route keypoint extraction, stitching, and warping through its live_portrait_wrapper methods.
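The new initialization calls a partial_fields helper that this diff does not define; presumably it constructs each config dataclass from the matching entries of args.__dict__. A minimal sketch of that assumption (mirroring LivePortrait's inference entry point, not part of this commit):

    # Assumed helper (not shown in this diff): keep only the kwargs that the
    # target dataclass declares as fields, then construct an instance from them.
    def partial_fields(target_class, kwargs):
        return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})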

Files changed (1)
  1. difpoint/inference.py +26 -26
difpoint/inference.py CHANGED
@@ -55,7 +55,16 @@ import os
 import datetime
 import platform
 from omegaconf import OmegaConf
-from difpoint.src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
+#from difpoint.src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
+from difpoint.src.live_portrait_pipeline import LivePortraitPipeline
+from difpoint.src.config.argument_config import ArgumentConfig
+from difpoint.src.config.inference_config import InferenceConfig
+from difpoint.src.config.crop_config import CropConfig
+from difpoint.src.live_portrait_pipeline import LivePortraitPipeline
+from difpoint.src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
+from difpoint.src.utils.camera import get_rotation_matrix
+from difpoint.src.utils.video import images2video, co
+
 
 FFMPEG = "ffmpeg"
 
@@ -178,13 +187,12 @@ class Inferencer(object):
         self.wav2lip_model.cuda()
         self.wav2lip_model.eval()
 
-        # specify configs for inference
-        self.inf_cfg = OmegaConf.load("difpoint/configs/trt_mp_infer.yaml")
-        self.inf_cfg.infer_params.flag_pasteback = False
+        args = tyro.cli(ArgumentConfig)
 
-        self.live_portrait_pipeline = FasterLivePortraitPipeline(cfg=self.inf_cfg, is_animal=False)
-        #ret = self.live_portrait_pipeline.prepare_source(source_image)
+        self.inf_cfg = partial_fields(InferenceConfig, args.__dict__)  # use attributes of args to initialize InferenceConfig
+        self.crop_cfg = partial_fields(CropConfig, args.__dict__)  # use attributes of args to initialize CropConfig
 
+        self.live_portrait_pipeline = LivePortraitPipeline(inference_cfg=self.inf_cfg, crop_cfg=self.crop_cfg)
         print('#'*25+f'End initialization, cost time {time.time()-st}'+'#'*25)
 
     def _norm(self, data_dict):
@@ -286,24 +294,15 @@ class Inferencer(object):
         else:
             input_image = image[0]
 
-        I_s = (torch.FloatTensor(input_image.transpose((2, 0, 1))).unsqueeze(0).cuda() / 255).cpu().numpy()
-        pitch, yaw, roll, t, exp, scale, kp = self.live_portrait_pipeline.model_dict["motion_extractor"].predict(
-            I_s)
-        x_s_info = {
-            "pitch": pitch,
-            "yaw": yaw,
-            "roll": roll,
-            "t": t,
-            "exp": exp,
-            "scale": scale,
-            "kp": kp
-        }
-        x_c_s = kp.reshape(1, 21, -1)
+        I_s = torch.FloatTensor(input_image.transpose((2, 0, 1))).unsqueeze(0).cuda() / 255
+
+        x_s_info = self.live_portrait_pipeline.live_portrait_wrapper.get_kp_info(I_s)
+        x_c_s = x_s_info['kp'].reshape(1, 21, -1)
         R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
-        f_s = self.live_portrait_pipeline.model_dict["app_feat_extractor"].predict(I_s)
-        x_s = transform_keypoint(pitch, yaw, roll, t, exp, scale, kp)
+        f_s = self.live_portrait_pipeline.live_portrait_wrapper.extract_feature_3d(I_s)
+        x_s = self.live_portrait_pipeline.live_portrait_wrapper.transform_keypoint(x_s_info)
 
-        flag_lip_zero = self.inf_cfg.infer_params.flag_normalize_lip
+        flag_lip_zero = self.inf_cfg.flag_lip_zero  # not overwrite
 
 
 
@@ -440,7 +439,7 @@ class Inferencer(object):
                 pass
             elif self.inf_cfg.infer_params.flag_stitching and not self.inf_cfg.infer_params.flag_eye_retargeting and not self.inf_cfg.infer_params.flag_lip_retargeting:
                 # with stitching and without retargeting
-                x_d_i_new = self.live_portrait_pipeline.stitching(x_s, x_d_i_new)
+                x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new)
             else:
                 eyes_delta, lip_delta = None, None
                 if self.inf_cfg.infer_params.flag_eye_retargeting:
@@ -458,10 +457,11 @@ class Inferencer(object):
                     (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
 
             if self.inf_cfg.infer_params.flag_stitching:
-                x_d_i_new = self.live_portrait_pipeline.stitching(x_s, x_d_i_new)
+                x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new)
 
-            out = self.live_portrait_pipeline.model_dict["warping_spade"].predict(f_s, x_s, x_d_i_new).cpu().numpy().astype(np.uint32)
-            I_p_lst.append(out)
+            out = self.live_portrait_pipeline.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
+            I_p_i = self.live_portrait_pipeline.live_portrait_wrapper.parse_output(out['out'])[0]
+            I_p_lst.append(I_p_i)
 
         video_name = os.path.basename(save_path)
         video_save_dir = os.path.dirname(save_path)
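Taken together, the hunks replace the per-model predict() calls with the LivePortraitWrapper API. A condensed sketch of the new source-feature and frame-generation path (names taken from the diff; the surrounding driving-frame loop and cropping logic are elided, and the wrapper method semantics are assumed from LivePortrait):

    # Sketch only: wrapper methods as used by this commit, not a full implementation.
    wrapper = self.live_portrait_pipeline.live_portrait_wrapper

    # Source image: HWC uint8 -> 1x3xHxW float tensor in [0, 1], kept on GPU.
    I_s = torch.FloatTensor(input_image.transpose((2, 0, 1))).unsqueeze(0).cuda() / 255
    x_s_info = wrapper.get_kp_info(I_s)          # pitch/yaw/roll/t/exp/scale/kp dict
    f_s = wrapper.extract_feature_3d(I_s)        # appearance features for warping
    x_s = wrapper.transform_keypoint(x_s_info)   # posed source keypoints

    # Per driving frame, after x_d_i_new is derived:
    x_d_i_new = wrapper.stitching(x_s, x_d_i_new)      # optional stitching pass
    out = wrapper.warp_decode(f_s, x_s, x_d_i_new)     # warp + SPADE decode
    I_p_lst.append(wrapper.parse_output(out['out'])[0])  # tensor -> uint8 frame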