Spaces: Running on Zero

Update difpoint/inference.py

difpoint/inference.py  CHANGED  (+26 -26)
@@ -55,7 +55,16 @@ import os
 import datetime
 import platform
 from omegaconf import OmegaConf
-from difpoint.src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
+#from difpoint.src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
+from difpoint.src.live_portrait_pipeline import LivePortraitPipeline
+from difpoint.src.config.argument_config import ArgumentConfig
+from difpoint.src.config.inference_config import InferenceConfig
+from difpoint.src.config.crop_config import CropConfig
+from difpoint.src.live_portrait_pipeline import LivePortraitPipeline
+from difpoint.src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
+from difpoint.src.utils.camera import get_rotation_matrix
+from difpoint.src.utils.video import images2video, co
+
 
 FFMPEG = "ffmpeg"
 
@@ -178,13 +187,12 @@ class Inferencer(object):
         self.wav2lip_model.cuda()
         self.wav2lip_model.eval()
 
-
-        self.inf_cfg = OmegaConf.load("difpoint/configs/trt_mp_infer.yaml")
-        self.inf_cfg.infer_params.flag_pasteback = False
+        args = tyro.cli(ArgumentConfig)
 
-        self.
-
+        self.inf_cfg = partial_fields(InferenceConfig, args.__dict__) # use attribute of args to initial InferenceConfig
+        self.crop_cfg = partial_fields(CropConfig, args.__dict__) # use attribute of args to initial CropConfig
 
+        self.live_portrait_pipeline = LivePortraitPipeline(inference_cfg=self.inf_cfg, crop_cfg=self.crop_cfg)
         print('#'*25+f'End initialization, cost time {time.time()-st}'+'#'*25)
 
     def _norm(self, data_dict):
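Note: this hunk calls `tyro.cli` and `partial_fields`, neither of which is added by the diff, so an `import tyro` and the helper must already exist elsewhere in difpoint/inference.py. A minimal sketch of `partial_fields`, assuming it matches the helper of the same name in upstream LivePortrait (hypothetical if this repo defines it differently):

    # Sketch (assumption): copy into the target dataclass only the CLI
    # attributes it declares, so one flat ArgumentConfig can seed the
    # narrower InferenceConfig and CropConfig objects.
    def partial_fields(target_class, kwargs):
        return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})

One design consequence: `tyro.cli(ArgumentConfig)` parses `sys.argv` at construction time, so the Inferencer inherits whatever flags the hosting process was launched with; the ArgumentConfig defaults apply when no recognized flags are present.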
@@ -286,24 +294,15 @@ class Inferencer(object):
         else:
             input_image = image[0]
 
-        I_s =
-
-
-        x_s_info = {
-            "pitch": pitch,
-            "yaw": yaw,
-            "roll": roll,
-            "t": t,
-            "exp": exp,
-            "scale": scale,
-            "kp": kp
-        }
-        x_c_s = kp.reshape(1, 21, -1)
+        I_s = torch.FloatTensor(input_image.transpose((2, 0, 1))).unsqueeze(0).cuda() / 255
+
+        x_s_info = self.live_portrait_pipeline.live_portrait_wrapper.get_kp_info(I_s)
+        x_c_s = x_s_info['kp'].reshape(1, 21, -1)
         R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
-        f_s = self.live_portrait_pipeline.
-        x_s = transform_keypoint(
+        f_s = self.live_portrait_pipeline.live_portrait_wrapper.extract_feature_3d(I_s)
+        x_s = self.live_portrait_pipeline.live_portrait_wrapper.transform_keypoint(x_s_info)
 
-        flag_lip_zero = self.inf_cfg.
+        flag_lip_zero = self.inf_cfg.flag_lip_zero  # not overwrite
 
 
 
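For context, `get_kp_info` returns the motion descriptors (`pitch`, `yaw`, `roll`, `t`, `exp`, `scale`, `kp`) that the deleted code assembled by hand, and `transform_keypoint` composes them into the posed source keypoints. A simplified sketch of that composition, assuming the upstream LivePortrait formulation x_s = scale * (x_c_s @ R + exp) + t with (1, 21, 3) tensors, as in the `x_c_s` reshape above:

    import torch
    from difpoint.src.utils.camera import get_rotation_matrix

    def transform_keypoint_sketch(x_s_info: dict) -> torch.Tensor:
        # Simplified: upstream also converts raw pose logits to degrees
        # before building the rotation matrix.
        kp = x_s_info['kp'].reshape(1, 21, 3)            # canonical keypoints x_c_s
        R = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
        x_s = kp @ R + x_s_info['exp'].reshape(1, 21, 3) # rotate, add expression deltas
        x_s = x_s * x_s_info['scale'][..., None]         # per-sample scale
        x_s[:, :, :2] += x_s_info['t'][:, None, :2]      # translate in x/y only, z dropped
        return x_s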
@@ -440,7 +439,7 @@ class Inferencer(object):
                 pass
             elif self.inf_cfg.infer_params.flag_stitching and not self.inf_cfg.infer_params.flag_eye_retargeting and not self.inf_cfg.infer_params.flag_lip_retargeting:
                 # with stitching and without retargeting
-                x_d_i_new = self.live_portrait_pipeline.stitching(x_s, x_d_i_new)
+                x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new)
             else:
                 eyes_delta, lip_delta = None, None
                 if self.inf_cfg.infer_params.flag_eye_retargeting:
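The `stitching` call only changes its access path here (pipeline object to its `live_portrait_wrapper`); the operation is unchanged. Conceptually, a small stitching network predicts keypoint offsets that pull the driving keypoints back toward the source so the animated crop pastes back without seams. A sketch under the assumption that it follows upstream LivePortrait; `stitch_net` is a hypothetical stand-in for the trained retargeting module:

    import torch

    def stitching_sketch(stitch_net, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        # stitch_net (hypothetical) maps the keypoint pair to a flat delta:
        # three values per keypoint plus a global 2D shift.
        bs, num_kp = kp_source.shape[:2]
        delta = stitch_net(kp_source, kp_driving)
        kp_new = kp_driving + delta[..., :3 * num_kp].reshape(bs, num_kp, 3)
        kp_new[..., :2] += delta[..., 3 * num_kp:3 * num_kp + 2].reshape(bs, 1, 2)
        return kp_new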
@@ -458,10 +457,11 @@ class Inferencer(object):
                     (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
 
                 if self.inf_cfg.infer_params.flag_stitching:
-                    x_d_i_new = self.live_portrait_pipeline.stitching(x_s, x_d_i_new)
+                    x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new)
 
-            out = self.live_portrait_pipeline.
-
+            out = self.live_portrait_pipeline.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
+            I_p_i = self.live_portrait_pipeline.live_portrait_wrapper.parse_output(out['out'])[0]
+            I_p_lst.append(I_p_i)
 
             video_name = os.path.basename(save_path)
             video_save_dir = os.path.dirname(save_path)
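This last hunk completes the previously truncated decode step: `warp_decode` warps the source appearance features `f_s` from the source keypoints `x_s` to the driven keypoints `x_d_i_new` and decodes a frame, and `parse_output` converts the decoder's float tensor to uint8 images before the frame is appended to `I_p_lst`. A sketch of the output conversion, assuming the decoder returns a (1, 3, H, W) float tensor in [0, 1] under the key 'out' as upstream LivePortrait does:

    import numpy as np
    import torch

    def parse_output_sketch(out: torch.Tensor) -> np.ndarray:
        # (1, 3, H, W) float in [0, 1] -> (1, H, W, 3) uint8 frames
        img = out.detach().cpu().numpy().transpose(0, 2, 3, 1)
        return (np.clip(img, 0, 1) * 255).astype(np.uint8)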