# -*- coding: utf-8 -*-
# @Author  : wenshao
# @Email   : [email protected]
# @Project : FasterLivePortrait
# @FileName: faster_live_portrait_pipeline.py

import copy
import pdb
import time
import traceback

from PIL import Image
import cv2
from tqdm import tqdm
import numpy as np
import torch

from .. import models
from ..utils.crop import crop_image, parse_bbox_from_landmark, crop_image_by_bbox, paste_back, paste_back_pytorch
from ..utils.utils import resize_to_limit, prepare_paste_back, get_rotation_matrix, calc_lip_close_ratio, \
    calc_eye_close_ratio, transform_keypoint, concat_feat
from difpoint.src.utils import utils


class FasterLivePortraitPipeline:
    def __init__(self, cfg, **kwargs):
        self.cfg = cfg
        self.init(**kwargs)

    def init(self, **kwargs):
        self.init_vars(**kwargs)
        self.init_models(**kwargs)

    def clean_models(self, **kwargs):
        """
        Release all loaded models.
        :param kwargs:
        :return:
        """
        for key in list(self.model_dict.keys()):
            del self.model_dict[key]
        self.model_dict = {}

    def init_models(self, **kwargs):
        if not kwargs.get("is_animal", False):
            print("load Human Model >>>")
            self.is_animal = False
            self.model_dict = {}
            for model_name in self.cfg.models:
                print(f"loading model: {model_name}")
                print(self.cfg.models[model_name])
                self.model_dict[model_name] = getattr(models, self.cfg.models[model_name]["name"])(
                    **self.cfg.models[model_name])
        else:
            print("load Animal Model >>>")
            self.is_animal = True
            self.model_dict = {}
            from src.utils.animal_landmark_runner import XPoseRunner
            from src.utils.utils import make_abs_path

            xpose_ckpt_path: str = make_abs_path("../difpoint/checkpoints/liveportrait_animal_onnx/xpose.pth")
            xpose_config_file_path: str = make_abs_path("models/XPose/config_model/UniPose_SwinT.py")
            xpose_embedding_cache_path: str = make_abs_path(
                '../difpoint/checkpoints/liveportrait_animal_onnx/clip_embedding')
            self.model_dict["xpose"] = XPoseRunner(model_config_path=xpose_config_file_path,
                                                   model_checkpoint_path=xpose_ckpt_path,
                                                   embeddings_cache_path=xpose_embedding_cache_path,
                                                   flag_use_half_precision=True)
            for model_name in self.cfg.animal_models:
                print(f"loading model: {model_name}")
                print(self.cfg.animal_models[model_name])
                self.model_dict[model_name] = getattr(models, self.cfg.animal_models[model_name]["name"])(
                    **self.cfg.animal_models[model_name])

    def init_vars(self, **kwargs):
        self.mask_crop = cv2.imread(self.cfg.infer_params.mask_crop_path, cv2.IMREAD_COLOR)
        self.frame_id = 0
        self.src_lmk_pre = None
        self.R_d_0 = None
        self.x_d_0_info = None
        self.R_d_smooth = utils.OneEuroFilter(4, 1)
        self.exp_smooth = utils.OneEuroFilter(4, 1)
        ## record information about the source
        self.source_path = None
        self.src_infos = []
        self.src_imgs = []
        self.is_source_video = False
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def calc_combined_eye_ratio(self, c_d_eyes_i, source_lmk):
        c_s_eyes = calc_eye_close_ratio(source_lmk[None])
        c_d_eyes_i = np.array(c_d_eyes_i).reshape(1, 1)
        # [c_s,eyes, c_d,eyes,i]
        combined_eye_ratio_tensor = np.concatenate([c_s_eyes, c_d_eyes_i], axis=1)
        return combined_eye_ratio_tensor

    def calc_combined_lip_ratio(self, c_d_lip_i, source_lmk):
        c_s_lip = calc_lip_close_ratio(source_lmk[None])
        c_d_lip_i = np.array(c_d_lip_i).reshape(1, 1)  # 1x1
        # [c_s,lip, c_d,lip,i]
        combined_lip_ratio_tensor = np.concatenate([c_s_lip, c_d_lip_i], axis=1)  # 1x2
        return combined_lip_ratio_tensor

    def prepare_source(self, source_path, **kwargs):
        print(f"process source: {source_path} >>>>>>>>")
        try:
            if utils.is_image(source_path):
                self.is_source_video = False
            elif utils.is_video(source_path):
                self.is_source_video = True
            else:  # source input is an unknown format
                raise Exception(f"Unknown source format: {source_path}")

            if self.is_source_video:
                src_imgs_bgr = []
                src_vcap = cv2.VideoCapture(source_path)
                while True:
                    ret, frame = src_vcap.read()
                    if not ret:
                        break
                    src_imgs_bgr.append(frame)
                src_vcap.release()
            else:
                img_bgr = cv2.imread(source_path, cv2.IMREAD_COLOR)
                src_imgs_bgr = [img_bgr]

            self.src_imgs = []
            self.src_infos = []
            self.source_path = source_path

            for ii, img_bgr in tqdm(enumerate(src_imgs_bgr), total=len(src_imgs_bgr)):
                img_bgr = resize_to_limit(img_bgr, self.cfg.infer_params.source_max_dim,
                                          self.cfg.infer_params.source_division)
                img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
                src_faces = []
                if self.is_animal:
                    with torch.no_grad():
                        img_rgb_pil = Image.fromarray(img_rgb)
                        lmk = self.model_dict["xpose"].run(
                            img_rgb_pil,
                            'face',
                            'animal_face',
                            0,
                            0
                        )
                    if lmk is None:
                        continue
                    self.src_imgs.append(img_rgb)
                    src_faces.append(lmk)
                else:
                    src_faces = self.model_dict["face_analysis"].predict(img_bgr)
                    if len(src_faces) == 0:
                        print("No face detected in this image.")
                        continue
                    self.src_imgs.append(img_rgb)

                # in realtime mode, only track the largest face
                if kwargs.get("realtime", False):
                    src_faces = src_faces[:1]

                crop_infos = []
                for i in range(len(src_faces)):
                    # NOTE: temporarily only pick the first face; multiple faces will be supported in the future
                    lmk = src_faces[i]
                    # crop the face
                    ret_dct = crop_image(
                        img_rgb,  # ndarray
                        lmk,  # 106x2 or Nx2
                        dsize=self.cfg.crop_params.src_dsize,
                        scale=self.cfg.crop_params.src_scale,
                        vx_ratio=self.cfg.crop_params.src_vx_ratio,
                        vy_ratio=self.cfg.crop_params.src_vy_ratio,
                    )
                    if self.is_animal:
                        ret_dct["lmk_crop"] = lmk
                    else:
                        lmk = self.model_dict["landmark"].predict(img_rgb, lmk)
                        ret_dct["lmk_crop"] = lmk
                        ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / self.cfg.crop_params.src_dsize

                    # update a 256x256 version for network input
                    ret_dct["img_crop_256x256"] = cv2.resize(
                        ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA
                    )
                    crop_infos.append(ret_dct)

                src_infos = [[] for _ in range(len(crop_infos))]
                for i, crop_info in enumerate(crop_infos):
                    source_lmk = crop_info['lmk_crop']
                    img_crop, img_crop_256x256 = crop_info['img_crop'], crop_info['img_crop_256x256']
                    pitch, yaw, roll, t, exp, scale, kp = self.model_dict["motion_extractor"].predict(
                        img_crop_256x256)
                    x_s_info = {
                        "pitch": pitch,
                        "yaw": yaw,
                        "roll": roll,
                        "t": t,
                        "exp": exp,
                        "scale": scale,
                        "kp": kp
                    }
                    src_infos[i].append(copy.deepcopy(x_s_info))
                    x_c_s = kp
                    R_s = get_rotation_matrix(pitch, yaw, roll)
                    f_s = self.model_dict["app_feat_extractor"].predict(img_crop_256x256)
                    x_s = transform_keypoint(pitch, yaw, roll, t, exp, scale, kp)
                    src_infos[i].extend([source_lmk.copy(), R_s.copy(), f_s.copy(), x_s.copy(), x_c_s.copy()])

                    if not self.is_animal:
                        flag_lip_zero = self.cfg.infer_params.flag_normalize_lip  # do not overwrite the config value
                        if flag_lip_zero:
                            # force the lip-open scalar to 0 before animation
                            c_d_lip_before_animation = [0.]
                            combined_lip_ratio_tensor_before_animation = self.calc_combined_lip_ratio(
                                c_d_lip_before_animation, source_lmk)
                            if combined_lip_ratio_tensor_before_animation[0][0] < \
                                    self.cfg.infer_params.lip_normalize_threshold:
                                flag_lip_zero = False
                                src_infos[i].append(None)
                                src_infos[i].append(flag_lip_zero)
                            else:
                                lip_delta_before_animation = self.model_dict['stitching_lip_retarget'].predict(
                                    concat_feat(x_s, combined_lip_ratio_tensor_before_animation))
                                src_infos[i].append(lip_delta_before_animation.copy())
                                src_infos[i].append(flag_lip_zero)
                        else:
                            src_infos[i].append(None)
                            src_infos[i].append(flag_lip_zero)
                    else:
                        src_infos[i].append(None)
                        src_infos[i].append(False)

                    ######## prepare for pasteback ########
                    if self.cfg.infer_params.flag_pasteback and self.cfg.infer_params.flag_do_crop \
                            and self.cfg.infer_params.flag_stitching:
                        mask_ori_float = prepare_paste_back(self.mask_crop, crop_info['M_c2o'],
                                                            dsize=(img_rgb.shape[1], img_rgb.shape[0]))
                        mask_ori_float = torch.from_numpy(mask_ori_float).to(self.device)
                        src_infos[i].append(mask_ori_float)
                    else:
                        src_infos[i].append(None)
                    M = torch.from_numpy(crop_info['M_c2o']).to(self.device)
                    src_infos[i].append(M)
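                    # src_infos[i] layout (unpacked positionally in run()):
                    # [x_s_info, source_lmk, R_s, f_s, x_s, x_c_s,
                    #  lip_delta_before_animation, flag_lip_zero, mask_ori_float, M]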
                self.src_infos.append(src_infos[:])

            print(f"finish process source: {source_path} >>>>>>>>")
            return len(self.src_infos) > 0
        except Exception:
            traceback.print_exc()
            return False

    def retarget_eye(self, kp_source, eye_close_ratio):
        """
        kp_source: BxNx3
        eye_close_ratio: Bx3
        Return: Bx(3*num_kp), reshaped to per-keypoint deltas by the caller
        """
        feat_eye = concat_feat(kp_source, eye_close_ratio)
        delta = self.model_dict['stitching_eye_retarget'].predict(feat_eye)
        return delta

    def retarget_lip(self, kp_source, lip_close_ratio):
        """
        kp_source: BxNx3
        lip_close_ratio: Bx2
        Return: Bx(3*num_kp), reshaped to per-keypoint deltas by the caller
        """
        feat_lip = concat_feat(kp_source, lip_close_ratio)
        delta = self.model_dict['stitching_lip_retarget'].predict(feat_lip)
        return delta

    def stitching(self, kp_source, kp_driving):
        """ conduct the stitching
        kp_source: Bxnum_kpx3
        kp_driving: Bxnum_kpx3
        """
        bs, num_kp = kp_source.shape[:2]
        kp_driving_new = kp_driving.copy()
        delta = self.model_dict['stitching'].predict(concat_feat(kp_source, kp_driving_new))
        delta_exp = delta[..., :3 * num_kp].reshape(bs, num_kp, 3)  # per-keypoint offsets, Bxnum_kpx3
        delta_tx_ty = delta[..., 3 * num_kp:3 * num_kp + 2].reshape(bs, 1, 2)  # global (tx, ty) shift, Bx1x2
        kp_driving_new += delta_exp
        kp_driving_new[..., :2] += delta_tx_ty
        return kp_driving_new

    def run(self, image, img_src, src_info, **kwargs):
        """
        Animate the source with one driving frame.
        :param image: driving frame (BGR ndarray)
        :param img_src: source image (RGB ndarray) from prepare_source
        :param src_info: per-face source info list produced by prepare_source
        :return: (img_crop, out_crop, out_org), or (None, None, None) if no face is found
        """
        img_bgr = image
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        I_p_pstbk = torch.from_numpy(img_src).to(self.device).float()
        realtime = kwargs.get("realtime", False)

        if self.cfg.infer_params.flag_crop_driving_video:
            if self.src_lmk_pre is None:
                src_face = self.model_dict["face_analysis"].predict(img_bgr)
                if len(src_face) == 0:
                    self.src_lmk_pre = None
                    return None, None, None
                lmk = src_face[0]
                lmk = self.model_dict["landmark"].predict(img_rgb, lmk)
                self.src_lmk_pre = lmk.copy()
            else:
                lmk = self.model_dict["landmark"].predict(img_rgb, self.src_lmk_pre)
                self.src_lmk_pre = lmk.copy()

            ret_bbox = parse_bbox_from_landmark(
                lmk,
                scale=self.cfg.crop_params.dri_scale,
                vx_ratio_crop_video=self.cfg.crop_params.dri_vx_ratio,
                vy_ratio=self.cfg.crop_params.dri_vy_ratio,
            )["bbox"]
            global_bbox = [
                ret_bbox[0, 0],
                ret_bbox[0, 1],
                ret_bbox[2, 0],
                ret_bbox[2, 1],
            ]
            ret_dct = crop_image_by_bbox(
                img_rgb,
                global_bbox,
                lmk=lmk,
                dsize=kwargs.get("dsize", 512),
                flag_rot=False,
                borderValue=(0, 0, 0),
            )
            lmk_crop = ret_dct["lmk_crop"]
            img_crop = ret_dct["img_crop"]
            img_crop = cv2.resize(img_crop, (256, 256))
        else:
            if self.src_lmk_pre is None:
                src_face = self.model_dict["face_analysis"].predict(img_bgr)
                if len(src_face) == 0:
                    self.src_lmk_pre = None
                    return None, None, None
                lmk = src_face[0]
                lmk = self.model_dict["landmark"].predict(img_rgb, lmk)
                self.src_lmk_pre = lmk.copy()
            else:
                lmk = self.model_dict["landmark"].predict(img_rgb, self.src_lmk_pre)
                self.src_lmk_pre = lmk.copy()
            lmk_crop = lmk.copy()
            img_crop = cv2.resize(img_rgb, (256, 256))

        input_eye_ratio = calc_eye_close_ratio(lmk_crop[None])
        input_lip_ratio = calc_lip_close_ratio(lmk_crop[None])
        pitch, yaw, roll, t, exp, scale, kp = self.model_dict["motion_extractor"].predict(img_crop)
        x_d_i_info = {
            "pitch": pitch,
            "yaw": yaw,
            "roll": roll,
            "t": t,
            "exp": exp,
            "scale": scale,
            "kp": kp
        }
        R_d_i = get_rotation_matrix(pitch, yaw, roll)

        if kwargs.get("first_frame", False) or self.R_d_0 is None:
            self.R_d_0 = R_d_i.copy()
            self.x_d_0_info = copy.deepcopy(x_d_i_info)
            # reset the realtime smoothing filters
            self.R_d_smooth = utils.OneEuroFilter(4, 1)
            self.exp_smooth = utils.OneEuroFilter(4, 1)
        R_d_0 = self.R_d_0.copy()
        x_d_0_info = copy.deepcopy(self.x_d_0_info)
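        # R_d_0 / x_d_0_info cache the first driving frame as the motion anchor:
        # under flag_relative_motion, the per-frame rotation, expression, scale
        # and translation below are applied as deltas against this anchor rather
        # than taken absolutely from the current driving frame.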

        out_crop, out_org = None, None
        for j in range(len(src_info)):
            x_s_info, source_lmk, R_s, f_s, x_s, x_c_s, lip_delta_before_animation, flag_lip_zero, mask_ori_float, M = \
                src_info[j]
            if self.cfg.infer_params.flag_relative_motion:
                if self.is_source_video:
                    if self.cfg.infer_params.flag_video_editing_head_rotation:
                        R_new = (R_d_i @ np.transpose(R_d_0, (0, 2, 1))) @ R_s
                        R_new = self.R_d_smooth.process(R_new)
                    else:
                        R_new = R_s
                else:
                    R_new = (R_d_i @ np.transpose(R_d_0, (0, 2, 1))) @ R_s
                delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
                if self.is_source_video:
                    delta_new = self.exp_smooth.process(delta_new)
                scale_new = x_s_info['scale'] if self.is_source_video else x_s_info['scale'] * (
                        x_d_i_info['scale'] / x_d_0_info['scale'])
                t_new = x_s_info['t'] if self.is_source_video else x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
            else:
                if self.is_source_video:
                    if self.cfg.infer_params.flag_video_editing_head_rotation:
                        R_new = R_d_i
                        R_new = self.R_d_smooth.process(R_new)
                    else:
                        R_new = R_s
                else:
                    R_new = R_d_i
                delta_new = x_d_i_info['exp'].copy()
                if self.is_source_video:
                    delta_new = self.exp_smooth.process(delta_new)
                scale_new = x_s_info['scale'].copy()
                t_new = x_d_i_info['t'].copy()
            t_new[..., 2] = 0  # zero tz
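            # Core keypoint transform, mirroring transform_keypoint() on the source side:
            # x_d,i = scale_new * (x_c,s @ R_new + delta_new) + t_new
            # (canonical keypoints rotated, offset by expression, then scaled and translated)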
            x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new

            if not self.is_animal:
                # Algorithm 1:
                if not self.cfg.infer_params.flag_stitching and not self.cfg.infer_params.flag_eye_retargeting \
                        and not self.cfg.infer_params.flag_lip_retargeting:
                    # without stitching or retargeting
                    if flag_lip_zero:
                        x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                    else:
                        pass
                elif self.cfg.infer_params.flag_stitching and not self.cfg.infer_params.flag_eye_retargeting \
                        and not self.cfg.infer_params.flag_lip_retargeting:
                    # with stitching and without retargeting
                    if flag_lip_zero:
                        x_d_i_new = self.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(
                            -1, x_s.shape[1], 3)
                    else:
                        x_d_i_new = self.stitching(x_s, x_d_i_new)
                else:
                    eyes_delta, lip_delta = None, None
                    if self.cfg.infer_params.flag_eye_retargeting:
                        c_d_eyes_i = input_eye_ratio
                        combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, source_lmk)
                        # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
                        eyes_delta = self.retarget_eye(x_s, combined_eye_ratio_tensor)
                    if self.cfg.infer_params.flag_lip_retargeting:
                        c_d_lip_i = input_lip_ratio
                        combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, source_lmk)
                        # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
                        lip_delta = self.retarget_lip(x_s, combined_lip_ratio_tensor)

                    if self.cfg.infer_params.flag_relative_motion:  # use x_s
                        x_d_i_new = x_s + \
                                    (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                                    (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
                    else:  # use x_d,i
                        x_d_i_new = x_d_i_new + \
                                    (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                                    (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
                    if self.cfg.infer_params.flag_stitching:
                        x_d_i_new = self.stitching(x_s, x_d_i_new)
            else:
                if self.cfg.infer_params.flag_stitching:
                    x_d_i_new = self.stitching(x_s, x_d_i_new)

            x_d_i_new = x_s + (x_d_i_new - x_s) * self.cfg.infer_params.driving_multiplier
            out_crop = self.model_dict["warping_spade"].predict(f_s, x_s, x_d_i_new)
            if not realtime and self.cfg.infer_params.flag_pasteback and self.cfg.infer_params.flag_do_crop \
                    and self.cfg.infer_params.flag_stitching:
                # TODO: pasteback is slow, consider optimizing it with multi-threading or on the GPU
                # I_p_pstbk = paste_back(out_crop, crop_info['M_c2o'], I_p_pstbk, mask_ori_float)
                I_p_pstbk = paste_back_pytorch(out_crop, M, I_p_pstbk, mask_ori_float)

        return img_crop, out_crop.to(dtype=torch.uint8).cpu().numpy(), I_p_pstbk.to(dtype=torch.uint8).cpu().numpy()

    def __del__(self):
        # guard against partially-initialized instances (e.g. init_models raised)
        if hasattr(self, "model_dict"):
            self.clean_models()
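

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline). The
# config and asset paths below are hypothetical placeholders, and because this
# module uses relative imports it should be driven from a proper entry script
# (or via `python -m ...`) rather than executed directly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/onnx_infer.yaml")  # hypothetical config path
    pipe = FasterLivePortraitPipeline(cfg, is_animal=False)
    assert pipe.prepare_source("assets/source.jpg")  # hypothetical source image

    vcap = cv2.VideoCapture("assets/driving.mp4")  # hypothetical driving video
    first = True
    while True:
        ret, frame = vcap.read()
        if not ret:
            break
        img_crop, out_crop, out_org = pipe.run(
            frame, pipe.src_imgs[0], pipe.src_infos[0], first_frame=first)
        first = False
        if out_crop is None:  # no face detected in this driving frame
            continue
        # out_org is the animated result pasted back into the full source image
    vcap.release()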