Create live_portrait_wrapper_cpu.py
src/live_portrait_wrapper_cpu.py  ADDED  (+288 -0)
@@ -0,0 +1,288 @@
# coding: utf-8

"""
Wrapper for LivePortrait core functions (CPU-optimized version)
"""

import os.path as osp
import numpy as np
import cv2
import torch
import yaml
import psutil

from .utils.timer import Timer
from .utils.helper_cpu import load_model, concat_feat
from .utils.camera import headpose_pred_to_degree, get_rotation_matrix
from .utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from .config.inference_config import InferenceConfig
from .utils.rprint import rlog as log


class LivePortraitWrapperCPU(object):

    def __init__(self, cfg: InferenceConfig):
        model_config = yaml.load(open(cfg.models_config, 'r'), Loader=yaml.SafeLoader)

        # Check available memory
        available_memory = psutil.virtual_memory().available / (1024 * 1024 * 1024)  # in GB
        if available_memory < 2:  # If less than 2GB available
            log(f"Warning: Only {available_memory:.2f}GB of RAM available. This may cause performance issues or crashes.")

        # init F
        self.appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, 'cpu', 'appearance_feature_extractor')
        log(f'Load appearance_feature_extractor done.')
        # init M
        self.motion_extractor = load_model(cfg.checkpoint_M, model_config, 'cpu', 'motion_extractor')
        log(f'Load motion_extractor done.')
        # init W
        self.warping_module = load_model(cfg.checkpoint_W, model_config, 'cpu', 'warping_module')
        log(f'Load warping_module done.')
        # init G
        self.spade_generator = load_model(cfg.checkpoint_G, model_config, 'cpu', 'spade_generator')
        log(f'Load spade_generator done.')
        # init S and R
        if cfg.checkpoint_S is not None and osp.exists(cfg.checkpoint_S):
            self.stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, 'cpu', 'stitching_retargeting_module')
            log(f'Load stitching_retargeting_module done.')
        else:
            self.stitching_retargeting_module = None
        self.device = 'cpu'
        self.cfg = cfg
        self.timer = Timer()

    def update_config(self, user_args):
        for k, v in user_args.items():
            if hasattr(self.cfg, k):
                setattr(self.cfg, k, v)

    def prepare_source(self, img: np.ndarray) -> torch.Tensor:
        """ construct the input as standard
        img: HxWx3, uint8, 256x256
        """
        h, w = img.shape[:2]
        if h != self.cfg.input_shape[0] or w != self.cfg.input_shape[1]:
            x = cv2.resize(img, (self.cfg.input_shape[0], self.cfg.input_shape[1]))
        else:
            x = img.copy()

        if x.ndim == 3:
            x = x[np.newaxis].astype(np.float32) / 255.  # HxWx3 -> 1xHxWx3, normalized to 0~1
        elif x.ndim == 4:
            x = x.astype(np.float32) / 255.  # BxHxWx3, normalized to 0~1
        else:
            raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
        x = np.clip(x, 0, 1)  # clip to 0~1
        x = torch.from_numpy(x).permute(0, 3, 1, 2)  # 1xHxWx3 -> 1x3xHxW
        return x

    def prepare_driving_videos(self, imgs) -> torch.Tensor:
        """ construct the input as standard
        imgs: NxBxHxWx3, uint8
        """
        if isinstance(imgs, list):
            _imgs = np.array(imgs)[..., np.newaxis]  # TxHxWx3x1
        elif isinstance(imgs, np.ndarray):
            _imgs = imgs
        else:
            raise ValueError(f'imgs type error: {type(imgs)}')

        y = _imgs.astype(np.float32) / 255.
        y = np.clip(y, 0, 1)  # clip to 0~1
        y = torch.from_numpy(y).permute(0, 4, 3, 1, 2)  # TxHxWx3x1 -> Tx1x3xHxW
        return y

    def extract_feature_3d(self, x: torch.Tensor) -> torch.Tensor:
        """ get the appearance feature of the image by F
        x: Bx3xHxW, normalized to 0~1
        """
        with torch.no_grad():
            feature_3d = self.appearance_feature_extractor(x)
        return feature_3d

    def get_kp_info(self, x: torch.Tensor, **kwargs) -> dict:
        """ get the implicit keypoint information
        x: Bx3xHxW, normalized to 0~1
        flag_refine_info: whether to transform the pose to degrees and the dimension of the reshape
        return: A dict contains keys: 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp'
        """
        with torch.no_grad():
            kp_info = self.motion_extractor(x)

        flag_refine_info: bool = kwargs.get('flag_refine_info', True)
        if flag_refine_info:
            bs = kp_info['kp'].shape[0]
            kp_info['pitch'] = headpose_pred_to_degree(kp_info['pitch'])[:, None]  # Bx1
            kp_info['yaw'] = headpose_pred_to_degree(kp_info['yaw'])[:, None]  # Bx1
            kp_info['roll'] = headpose_pred_to_degree(kp_info['roll'])[:, None]  # Bx1
            kp_info['kp'] = kp_info['kp'].reshape(bs, -1, 3)  # BxNx3
            kp_info['exp'] = kp_info['exp'].reshape(bs, -1, 3)  # BxNx3

        return kp_info

    def get_pose_dct(self, kp_info: dict) -> dict:
        pose_dct = dict(
            pitch=headpose_pred_to_degree(kp_info['pitch']).item(),
            yaw=headpose_pred_to_degree(kp_info['yaw']).item(),
            roll=headpose_pred_to_degree(kp_info['roll']).item(),
        )
        return pose_dct

    def get_fs_and_kp_info(self, source_prepared, driving_first_frame):
        # get the canonical keypoints of source image by M
        source_kp_info = self.get_kp_info(source_prepared, flag_refine_info=True)
        source_rotation = get_rotation_matrix(source_kp_info['pitch'], source_kp_info['yaw'], source_kp_info['roll'])

        # get the canonical keypoints of first driving frame by M
        driving_first_frame_kp_info = self.get_kp_info(driving_first_frame, flag_refine_info=True)
        driving_first_frame_rotation = get_rotation_matrix(
            driving_first_frame_kp_info['pitch'],
            driving_first_frame_kp_info['yaw'],
            driving_first_frame_kp_info['roll']
        )

        # get feature volume by F
        source_feature_3d = self.extract_feature_3d(source_prepared)

        return source_kp_info, source_rotation, source_feature_3d, driving_first_frame_kp_info, driving_first_frame_rotation

    def transform_keypoint(self, kp_info: dict):
        """
        transform the implicit keypoints with the pose, shift, and expression deformation
        kp: BxNx3
        """
        kp = kp_info['kp']    # (bs, k, 3)
        pitch, yaw, roll = kp_info['pitch'], kp_info['yaw'], kp_info['roll']

        t, exp = kp_info['t'], kp_info['exp']
        scale = kp_info['scale']

        pitch = headpose_pred_to_degree(pitch)
        yaw = headpose_pred_to_degree(yaw)
        roll = headpose_pred_to_degree(roll)

        bs = kp.shape[0]
        if kp.ndim == 2:
            num_kp = kp.shape[1] // 3  # Bx(num_kpx3)
        else:
            num_kp = kp.shape[1]  # Bxnum_kpx3

        rot_mat = get_rotation_matrix(pitch, yaw, roll)    # (bs, 3, 3)

        # Eqn.2: s * (R * x_c,s + exp) + t
        kp_transformed = kp.view(bs, num_kp, 3) @ rot_mat + exp.view(bs, num_kp, 3)
        kp_transformed *= scale[..., None]  # (bs, k, 3) * (bs, 1, 1) = (bs, k, 3)
        kp_transformed[:, :, 0:2] += t[:, None, 0:2]  # remove z, only apply tx ty

        return kp_transformed
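    # Illustrative shape note (added for clarity, not in the original LivePortrait code).
    # transform_keypoint evaluates Eqn.2 with batched tensors; for one image
    # (bs = 1) with num_kp implicit keypoints the operands are roughly:
    #   kp       (1, num_kp, 3)  canonical keypoints x_c
    #   rot_mat  (1, 3, 3)       R built from pitch/yaw/roll (in degrees)
    #   exp      (1, num_kp, 3)  expression deformation
    #   scale    (1, 1)          broadcast to (1, 1, 1) via scale[..., None]
    #   t        (1, 3)          only tx, ty are added; the z translation is dropped
    # i.e. kp_transformed = scale * (kp @ rot_mat + exp), then tx, ty are added.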

    def retarget_eye(self, kp_source: torch.Tensor, eye_close_ratio: torch.Tensor) -> torch.Tensor:
        """
        kp_source: BxNx3
        eye_close_ratio: Bx3
        Return: Bx(3*num_kp+2)
        """
        feat_eye = concat_feat(kp_source, eye_close_ratio)

        with torch.no_grad():
            delta = self.stitching_retargeting_module['eye'](feat_eye)

        return delta

    def retarget_lip(self, kp_source: torch.Tensor, lip_close_ratio: torch.Tensor) -> torch.Tensor:
        """
        kp_source: BxNx3
        lip_close_ratio: Bx2
        """
        feat_lip = concat_feat(kp_source, lip_close_ratio)

        with torch.no_grad():
            delta = self.stitching_retargeting_module['lip'](feat_lip)

        return delta

    def stitch(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        """
        kp_source: BxNx3
        kp_driving: BxNx3
        Return: Bx(3*num_kp+2)
        """
        feat_stiching = concat_feat(kp_source, kp_driving)

        with torch.no_grad():
            delta = self.stitching_retargeting_module['stitching'](feat_stiching)

        return delta

    def stitching(self, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        """ conduct the stitching
        kp_source: Bxnum_kpx3
        kp_driving: Bxnum_kpx3
        """
        if self.stitching_retargeting_module is not None:
            bs, num_kp = kp_source.shape[:2]

            kp_driving_new = kp_driving.clone()
            delta = self.stitch(kp_source, kp_driving_new)

            delta_exp = delta[..., :3*num_kp].reshape(bs, num_kp, 3)  # 1x20x3
            delta_tx_ty = delta[..., 3*num_kp:3*num_kp+2].reshape(bs, 1, 2)  # 1x1x2

            kp_driving_new += delta_exp
            kp_driving_new[..., :2] += delta_tx_ty

            return kp_driving_new

        return kp_driving

    def warp_decode(self, feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
        """ get the image after the warping of the implicit keypoints
        feature_3d: Bx32x16x64x64, feature volume
        kp_source: BxNx3
        kp_driving: BxNx3
        """
        # The line 18 in Algorithm 1: D(W(f_s; x_s, x′_d,i))
        with torch.no_grad():
            # get decoder input
            ret_dct = self.warping_module(feature_3d, kp_source=kp_source, kp_driving=kp_driving)
            # decode
            ret_dct['out'] = self.spade_generator(feature=ret_dct['out'])

        return ret_dct

    def parse_output(self, out: torch.Tensor) -> np.ndarray:
        """ construct the output as standard
        return: 1xHxWx3, uint8
        """
        out = np.transpose(out.data.numpy(), [0, 2, 3, 1])  # 1x3xHxW -> 1xHxWx3
        out = np.clip(out, 0, 1)  # clip to 0~1
        out = np.clip(out * 255, 0, 255).astype(np.uint8)  # 0~1 -> 0~255

        return out

    def calc_retargeting_ratio(self, source_lmk, driving_lmk_lst):
        input_eye_ratio_lst = []
        input_lip_ratio_lst = []
        for lmk in driving_lmk_lst:
            # for eyes retargeting
            input_eye_ratio_lst.append(calc_eye_close_ratio(lmk[None]))
            # for lip retargeting
            input_lip_ratio_lst.append(calc_lip_close_ratio(lmk[None]))
        return input_eye_ratio_lst, input_lip_ratio_lst

    def calc_combined_eye_ratio(self, input_eye_ratio, source_lmk):
        eye_close_ratio = calc_eye_close_ratio(source_lmk[None])
        eye_close_ratio_tensor = torch.from_numpy(eye_close_ratio).float().to(self.device)
        input_eye_ratio_tensor = torch.tensor([input_eye_ratio[0][0]]).reshape(1, 1).to(self.device)
        # [c_s,eyes, c_d,eyes,i]
        combined_eye_ratio_tensor = torch.cat([eye_close_ratio_tensor, input_eye_ratio_tensor], dim=1)
        return combined_eye_ratio_tensor

    def calc_combined_lip_ratio(self, input_lip_ratio, source_lmk):
        lip_close_ratio = calc_lip_close_ratio(source_lmk[None])
        lip_close_ratio_tensor = torch.from_numpy(lip_close_ratio).float().to(self.device)
        # [c_s,lip, c_d,lip,i]
        input_lip_ratio_tensor = torch.tensor([input_lip_ratio[0]]).to(self.device)
        if input_lip_ratio_tensor.shape != torch.Size([1, 1]):
            input_lip_ratio_tensor = input_lip_ratio_tensor.reshape(1, 1)
        combined_lip_ratio_tensor = torch.cat([lip_close_ratio_tensor, input_lip_ratio_tensor], dim=1)
        return combined_lip_ratio_tensor
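A minimal end-to-end sketch of how this wrapper is typically driven (illustrative only: it assumes the `src` package layout implied by the imports above, an `InferenceConfig` whose default checkpoint paths exist locally, and it skips the repo's cropping/landmark stage, so the inputs are already 256x256 face crops; the image file names are placeholders):

# Illustrative usage sketch, not part of this commit.
import cv2

from src.config.inference_config import InferenceConfig
from src.live_portrait_wrapper_cpu import LivePortraitWrapperCPU

cfg = InferenceConfig()                     # assumes default checkpoint paths are present
wrapper = LivePortraitWrapperCPU(cfg)

source = cv2.cvtColor(cv2.imread('source.jpg'), cv2.COLOR_BGR2RGB)    # placeholder 256x256 crop
driving = cv2.cvtColor(cv2.imread('driving.jpg'), cv2.COLOR_BGR2RGB)  # placeholder 256x256 crop

x_s_prep = wrapper.prepare_source(source)   # 1x3x256x256, float in [0, 1]
f_s = wrapper.extract_feature_3d(x_s_prep)  # appearance feature volume (F)
x_s_info = wrapper.get_kp_info(x_s_prep)    # pitch/yaw/roll/t/exp/scale/kp (M)
x_s = wrapper.transform_keypoint(x_s_info)  # source keypoints via Eqn.2

x_d_prep = wrapper.prepare_source(driving)
x_d_info = wrapper.get_kp_info(x_d_prep)
x_d = wrapper.transform_keypoint(x_d_info)  # driving keypoints via Eqn.2
x_d = wrapper.stitching(x_s, x_d)           # returns kp_driving unchanged when checkpoint_S is absent

ret = wrapper.warp_decode(f_s, x_s, x_d)    # warp (W) + decode (G)
frame = wrapper.parse_output(ret['out'])[0] # HxWx3, uint8, RGB
cv2.imwrite('animated.jpg', cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))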