# -*- coding: utf-8 -*-
# @Author  : wenshao
# @Email   : wenshaoguo0611@gmail.com
# @Project : FasterLivePortrait
# @FileName: gradio_live_portrait_pipeline.py

import pdb
import gradio as gr
import cv2
import datetime
import os
import time
from tqdm import tqdm
import subprocess
import numpy as np

from .faster_live_portrait_pipeline import FasterLivePortraitPipeline
from ..utils.utils import video_has_audio
from ..utils.utils import resize_to_limit, prepare_paste_back, get_rotation_matrix, calc_lip_close_ratio, \
    calc_eye_close_ratio, transform_keypoint, concat_feat
from ..utils.crop import crop_image, parse_bbox_from_landmark, crop_image_by_bbox, paste_back, paste_back_pytorch
from src.utils import utils
import platform
import torch
from PIL import Image

if platform.system().lower() == 'windows':
    FFMPEG = "third_party/ffmpeg-7.0.1-full_build/bin/ffmpeg.exe"
else:
    FFMPEG = "ffmpeg"


class GradioLivePortraitPipeline(FasterLivePortraitPipeline):

    def __init__(self, cfg, **kwargs):
        super(GradioLivePortraitPipeline, self).__init__(cfg, **kwargs)

    def update_cfg(self, args_user):
        """Merge user-supplied arguments into the inference/crop config; return True if anything changed."""
        update_ret = False
        for key in args_user:
            if key in self.cfg.infer_params:
                if self.cfg.infer_params[key] != args_user[key]:
                    update_ret = True
                    print("update infer cfg {} from {} to {}".format(key, self.cfg.infer_params[key], args_user[key]))
                    self.cfg.infer_params[key] = args_user[key]
            elif key in self.cfg.crop_params:
                if self.cfg.crop_params[key] != args_user[key]:
                    update_ret = True
                    print("update crop cfg {} from {} to {}".format(key, self.cfg.crop_params[key], args_user[key]))
                    self.cfg.crop_params[key] = args_user[key]
            else:
                # unknown key: add it to the inference config and flag the change
                update_ret = True
                print("add {}:{} to infer cfg".format(key, args_user[key]))
                self.cfg.infer_params[key] = args_user[key]
        return update_ret

    def execute_video(
            self,
            input_source_image_path=None,
            input_source_video_path=None,
            input_driving_video_path=None,
            flag_relative_input=True,
            flag_do_crop_input=True,
            flag_remap_input=True,
            driving_multiplier=1.0,
            flag_stitching=True,
            flag_crop_driving_video_input=True,
            flag_video_editing_head_rotation=False,
            flag_is_animal=False,
            scale=2.3,
            vx_ratio=0.0,
            vy_ratio=-0.125,
            scale_crop_driving_video=2.2,
            vx_ratio_crop_driving_video=0.0,
            vy_ratio_crop_driving_video=-0.1,
            driving_smooth_observation_variance=1e-7,
            tab_selection=None,
    ):
        """
        for video-driven portrait animation
        """
        if tab_selection == 'Image':
            input_source_path = input_source_image_path
        elif tab_selection == 'Video':
            input_source_path = input_source_video_path
        else:
            input_source_path = input_source_image_path

        if flag_is_animal != self.is_animal:
            self.init_models(is_animal=flag_is_animal)

        if input_source_path is not None and input_driving_video_path is not None:
            args_user = {
                'source': input_source_path,
                'driving': input_driving_video_path,
                'flag_relative_motion': flag_relative_input,
                'flag_do_crop': flag_do_crop_input,
                'flag_pasteback': flag_remap_input,
                'driving_multiplier': driving_multiplier,
                'flag_stitching': flag_stitching,
                'flag_crop_driving_video': flag_crop_driving_video_input,
                'flag_video_editing_head_rotation': flag_video_editing_head_rotation,
                'src_scale': scale,
                'src_vx_ratio': vx_ratio,
                'src_vy_ratio': vy_ratio,
                'dri_scale': scale_crop_driving_video,
                'dri_vx_ratio': vx_ratio_crop_driving_video,
                'dri_vy_ratio': vy_ratio_crop_driving_video,
                'driving_smooth_observation_variance': driving_smooth_observation_variance,
            }
            # update config from user input
            update_ret = self.update_cfg(args_user)
            # video-driven animation
            video_path, video_path_concat, total_time = self.run_local(input_driving_video_path, input_source_path,
                                                                        update_ret=update_ret)
            gr.Info(f"Run successfully! Cost: {total_time:.2f} seconds!", duration=3)
            return video_path, video_path_concat
        else:
            raise gr.Error("The input source portrait or driving video hasn't been prepared yet 💥!", duration=5)

    def run_local(self, driving_video_path, source_path, **kwargs):
        t00 = time.time()
        if self.source_path != source_path or kwargs.get("update_ret", False):
            # the source or the config changed, so re-initialize the pipeline state
            self.init_vars(**kwargs)
            ret = self.prepare_source(source_path)
            if not ret:
                raise gr.Error(f"Error processing source: {source_path} 💥!", duration=5)
        vcap = cv2.VideoCapture(driving_video_path)
        if self.is_source_video:
            duration, fps = utils.get_video_info(self.source_path)
            fps = int(fps)
        else:
            fps = int(vcap.get(cv2.CAP_PROP_FPS))
        dframe = int(vcap.get(cv2.CAP_PROP_FRAME_COUNT))
        if self.is_source_video:
            max_frame = min(dframe, len(self.src_imgs))
        else:
            max_frame = dframe
        h, w = self.src_imgs[0].shape[:2]
        save_dir = f"./results/{datetime.datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
        os.makedirs(save_dir, exist_ok=True)

        # render output videos
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        vsave_crop_path = os.path.join(save_dir,
                                       f"{os.path.basename(source_path)}-{os.path.basename(driving_video_path)}-crop.mp4")
        vout_crop = cv2.VideoWriter(vsave_crop_path, fourcc, fps, (512 * 2, 512))
        vsave_org_path = os.path.join(save_dir,
                                      f"{os.path.basename(source_path)}-{os.path.basename(driving_video_path)}-org.mp4")
        vout_org = cv2.VideoWriter(vsave_org_path, fourcc, fps, (w, h))

        infer_times = []
        for i in tqdm(range(max_frame)):
            ret, frame = vcap.read()
            if not ret:
                break
            t0 = time.time()
            first_frame = i == 0
            if self.is_source_video:
                dri_crop, out_crop, out_org = self.run(frame, self.src_imgs[i], self.src_infos[i],
                                                       first_frame=first_frame)
            else:
                dri_crop, out_crop, out_org = self.run(frame, self.src_imgs[0], self.src_infos[0],
                                                       first_frame=first_frame)
            if out_crop is None:
                print(f"no face detected in driving frame {i}")
                continue
            infer_times.append(time.time() - t0)
            # write the side-by-side (driving | generated) crop video and the full-resolution result
            dri_crop = cv2.resize(dri_crop, (512, 512))
            out_crop = np.concatenate([dri_crop, out_crop], axis=1)
            out_crop = cv2.cvtColor(out_crop, cv2.COLOR_RGB2BGR)
            vout_crop.write(out_crop)
            out_org = cv2.cvtColor(out_org, cv2.COLOR_RGB2BGR)
            vout_org.write(out_org)
        total_time = time.time() - t00
        vcap.release()
        vout_crop.release()
        vout_org.release()

        if video_has_audio(driving_video_path):
            # remux the driving video's audio track into the rendered results
            vsave_crop_path_new = os.path.splitext(vsave_crop_path)[0] + "-audio.mp4"
            vsave_org_path_new = os.path.splitext(vsave_org_path)[0] + "-audio.mp4"
            if self.is_source_video:
                duration, fps = utils.get_video_info(vsave_crop_path)
                subprocess.call(
                    [FFMPEG, "-i", vsave_crop_path, "-i", driving_video_path,
                     "-b:v", "10M", "-c:v", "libx264", "-map", "0:v", "-map", "1:a",
                     "-c:a", "aac", "-pix_fmt", "yuv420p",
                     "-shortest",  # stop at the end of the shortest stream
                     "-t", str(duration),  # set the output duration
                     "-r", str(fps),  # set the output frame rate
                     vsave_crop_path_new, "-y"])
                subprocess.call(
                    [FFMPEG, "-i", vsave_org_path, "-i", driving_video_path,
                     "-b:v", "10M", "-c:v", "libx264", "-map", "0:v", "-map", "1:a",
                     "-c:a", "aac", "-pix_fmt", "yuv420p",
                     "-shortest",  # stop at the end of the shortest stream
                     "-t", str(duration),  # set the output duration
                     "-r", str(fps),  # set the output frame rate
                     vsave_org_path_new, "-y"])
            else:
                subprocess.call(
                    [FFMPEG, "-i", vsave_crop_path, "-i", driving_video_path, "-b:v", "10M", "-c:v",
                     "libx264", "-map", "0:v", "-map", "1:a", "-c:a", "aac", "-pix_fmt", "yuv420p",
                     vsave_crop_path_new, "-y", "-shortest"])
                subprocess.call(
                    [FFMPEG, "-i", vsave_org_path, "-i", driving_video_path, "-b:v", "10M", "-c:v",
                     "libx264", "-map", "0:v", "-map", "1:a", "-c:a", "aac", "-pix_fmt", "yuv420p",
                     vsave_org_path_new, "-y", "-shortest"])
            return vsave_org_path_new, vsave_crop_path_new, total_time
        else:
            return vsave_org_path, vsave_crop_path, total_time

    def execute_image(self, input_eye_ratio: float, input_lip_ratio: float, input_image, flag_do_crop=True):
        """
        for single image retargeting
        """
        # disposable feature
        f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb = \
            self.prepare_retargeting(input_image, flag_do_crop)
        if input_eye_ratio is None or input_lip_ratio is None:
            raise gr.Error("Invalid ratio input 💥!", duration=5)
        else:
            # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
            combined_eye_ratio_tensor = self.calc_combined_eye_ratio([[input_eye_ratio]], source_lmk_user)
            eyes_delta = self.retarget_eye(x_s_user, combined_eye_ratio_tensor)
            # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
            combined_lip_ratio_tensor = self.calc_combined_lip_ratio([[input_lip_ratio]], source_lmk_user)
            lip_delta = self.retarget_lip(x_s_user, combined_lip_ratio_tensor)
            num_kp = x_s_user.shape[1]
            # default: use x_s
            x_d_new = x_s_user + eyes_delta.reshape(-1, num_kp, 3) + lip_delta.reshape(-1, num_kp, 3)
            # D(W(f_s; x_s, x′_d))
            out = self.model_dict["warping_spade"].predict(f_s_user, x_s_user, x_d_new)
            img_rgb = torch.from_numpy(img_rgb).to(self.device)
            out_to_ori_blend = paste_back_pytorch(out, crop_M_c2o, img_rgb, mask_ori)
            gr.Info("Run successfully!", duration=2)
            return out.to(dtype=torch.uint8).cpu().numpy(), out_to_ori_blend.to(dtype=torch.uint8).cpu().numpy()

    def prepare_retargeting(self, input_image, flag_do_crop=True):
        """
        for single image retargeting
        """
        if input_image is not None:
            ######## process source portrait ########
            img_bgr = cv2.imread(input_image, cv2.IMREAD_COLOR)
            img_bgr = resize_to_limit(img_bgr, self.cfg.infer_params.source_max_dim,
                                      self.cfg.infer_params.source_division)
            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
            if self.is_animal:
                raise gr.Error("The animal model is not supported for face retargeting 💥!", duration=5)
            else:
                src_faces = self.model_dict["face_analysis"].predict(img_bgr)
                if len(src_faces) == 0:
                    raise gr.Error("No face detected in the image 💥!", duration=5)
                src_faces = src_faces[:1]
            crop_infos = []
            for i in range(len(src_faces)):
                # NOTE: temporarily only pick the first face; multiple faces may be supported in the future
                lmk = src_faces[i]
                # crop the face
                ret_dct = crop_image(
                    img_rgb,  # ndarray
                    lmk,  # 106x2 or Nx2
                    dsize=self.cfg.crop_params.src_dsize,
                    scale=self.cfg.crop_params.src_scale,
                    vx_ratio=self.cfg.crop_params.src_vx_ratio,
                    vy_ratio=self.cfg.crop_params.src_vy_ratio,
                )
                lmk = self.model_dict["landmark"].predict(img_rgb, lmk)
                ret_dct["lmk_crop"] = lmk
                # update a 256x256 version for network input
                ret_dct["img_crop_256x256"] = cv2.resize(
                    ret_dct["img_crop"], (256, 256), interpolation=cv2.INTER_AREA
                )
                ret_dct["lmk_crop_256x256"] = ret_dct["lmk_crop"] * 256 / self.cfg.crop_params.src_dsize
                crop_infos.append(ret_dct)
            crop_info = crop_infos[0]
            if flag_do_crop:
                I_s = crop_info['img_crop_256x256'].copy()
            else:
                I_s = img_rgb.copy()
            pitch, yaw, roll, t, exp, scale, kp = self.model_dict["motion_extractor"].predict(I_s)
            x_s_info = {
                "pitch": pitch,
                "yaw": yaw,
                "roll": roll,
                "t": t,
                "exp": exp,
                "scale": scale,
                "kp": kp
            }
            R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
            ############################################
            f_s_user = self.model_dict["app_feat_extractor"].predict(I_s)
            x_s_user = transform_keypoint(pitch, yaw, roll, t, exp, scale, kp)
            source_lmk_user = crop_info['lmk_crop']
            crop_M_c2o = crop_info['M_c2o']
            crop_M_c2o = torch.from_numpy(crop_M_c2o).to(self.device)
            mask_ori = prepare_paste_back(self.mask_crop, crop_info['M_c2o'],
                                          dsize=(img_rgb.shape[1], img_rgb.shape[0]))
            mask_ori = torch.from_numpy(mask_ori).to(self.device).float()
            return f_s_user, x_s_user, source_lmk_user, crop_M_c2o, mask_ori, img_rgb
        else:
            # when the clear button is pressed, control reaches here
            raise gr.Error("The retargeting input hasn't been prepared yet 💥!", duration=5)
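

# The block below is a minimal, hypothetical usage sketch and is not part of the original module.
# It assumes an OmegaConf-style config file at "configs/trt_infer.yaml" and example asset paths;
# adjust both to your local FasterLivePortrait setup. Note that execute_video() calls gr.Info/gr.Error,
# which are intended to run inside a Gradio request context, so this is illustrative only.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    # Load the inference config and build the pipeline (config path is an assumption).
    infer_cfg = OmegaConf.load("configs/trt_infer.yaml")
    pipeline = GradioLivePortraitPipeline(infer_cfg)

    # Animate a source portrait with a driving video, mirroring the Gradio "Image" tab.
    org_path, concat_path = pipeline.execute_video(
        input_source_image_path="assets/examples/source/s1.jpg",    # example path
        input_driving_video_path="assets/examples/driving/d1.mp4",  # example path
        tab_selection="Image",
    )
    print("full-frame result:", org_path)
    print("side-by-side result:", concat_path)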