jhj0517 committed
Commit: aa337d6 · Parent(s): 6a40fe0

Fix type and bugs

modules/live_portrait/live_portrait_inferencer.py CHANGED

@@ -243,14 +243,13 @@ class LivePortraitInferencer:
 
     def create_video(self,
                      model_type: str = ModelType.HUMAN.value,
-                     retargeting_eyes:
-                     retargeting_mouth:
-                     tracking_src_vid: bool =
+                     retargeting_eyes: float = 0,
+                     retargeting_mouth: float = 0,
+                     tracking_src_vid: bool = False,
                      animate_without_vid: bool = False,
                      crop_factor: float = 1.5,
-
+                     src_image: Optional[str] = None,
                      driving_vid_path: Optional[str] = None,
-                     driving_images: Optional[List[np.ndarray]] = None,
                      progress: gr.Progress = gr.Progress()
                      ):
         if self.pipeline is None or model_type != self.model_type:
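
For context, a call against the new `create_video` signature might look like the sketch below. The constructor arguments and asset paths are placeholders, not taken from this commit.

```python
# Hypothetical usage of the new signature; paths are placeholders.
inferencer = LivePortraitInferencer()
inferencer.create_video(
    model_type=ModelType.HUMAN.value,
    retargeting_eyes=0,
    retargeting_mouth=0,
    tracking_src_vid=False,
    animate_without_vid=False,
    crop_factor=1.5,
    src_image="assets/source.png",         # path string, per the new Optional[str] type
    driving_vid_path="assets/driving.mp4"  # frames/sound are now extracted internally
)
```
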
@@ -260,16 +259,17 @@ class LivePortraitInferencer:
 
         src_length = 1
 
-        if
-            src_length = len(
-            if id(
+        if src_image is not None:
+            src_length = len(src_image)
+            if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
                 self.crop_factor = crop_factor
-
+
                 if 1 < src_length:
-                    self.psi_list =
+                    self.psi_list = self.prepare_source(src_image, crop_factor, True, tracking_src_vid)
                 else:
-                    self.psi_list =
+                    self.psi_list = self.prepare_source(src_image, crop_factor)
 
+        driving_images, vid_sound = extract_frames(driving_vid_path), extract_sound(driving_vid_path)
         driving_length = 0
         if driving_images is not None:
             if id(driving_images) != id(self.driving_images):
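
The hunk above drops the `driving_images` parameter and instead pulls frames and audio from `driving_vid_path` via `extract_frames` and `extract_sound`. Those helpers are defined elsewhere in the repo; a minimal sketch of what `extract_frames` plausibly does, assuming OpenCV decoding and RGB output:

```python
import os
from typing import List, Optional

import cv2
import numpy as np

def extract_frames(vid_path: Optional[str]) -> Optional[List[np.ndarray]]:
    # Decode every frame of the video into an RGB uint8 array.
    if vid_path is None or not os.path.exists(vid_path):
        return None
    cap = cv2.VideoCapture(vid_path)
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))  # cv2 reads BGR
    cap.release()
    return frames or None
```

`extract_sound` would analogously pull the audio track (e.g. with ffmpeg) so it can be muxed back into the rendered video.
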
@@ -281,6 +281,7 @@ class LivePortraitInferencer:
 
         if animate_without_vid:
             total_length = total_length
+            self.psi_list = [self.psi_list[0] for _ in range(total_length)]
 
         c_i_es = ExpressionSet()
         c_o_es = ExpressionSet()
@@ -292,6 +293,7 @@ class LivePortraitInferencer:
 
             if i < src_length:
                 psi = self.psi_list[i]
+
                 s_info = psi.x_s_info
                 s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))
 
@@ -299,7 +301,7 @@ class LivePortraitInferencer:
 
             if i < driving_length:
                 d_i_info = self.driving_values[i]
-                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])#.float().to(device="cuda:0")
+                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']]) #.float().to(device="cuda:0")
 
                 if d_0_es is None:
                     d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))
@@ -532,8 +534,15 @@ class LivePortraitInferencer:
         return new_img
 
     def prepare_src_image(self, img):
-
-
+        if isinstance(img, str):
+            img = image_path_to_array(img)
+
+        if len(img.shape) <= 3:
+            img = img[np.newaxis, ...]
+
+        d, h, w, c = img.shape
+        img = img[0] # Select first dimension
+        input_shape = [256, 256]
         if h != input_shape[0] or w != input_shape[1]:
             if 256 < h: interpolation = cv2.INTER_AREA
             else: interpolation = cv2.INTER_LINEAR
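
`image_path_to_array`, called by the new `prepare_src_image`, is not part of this diff. Assuming it simply loads a file into an RGB array, it could look like:

```python
import cv2
import numpy as np

def image_path_to_array(img_path: str) -> np.ndarray:
    # Load an image file as an RGB uint8 HxWxC array (assumed contract).
    img = cv2.imread(img_path, cv2.IMREAD_COLOR)  # OpenCV returns BGR
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
```

The `img[np.newaxis, ...]` step then gives both still images and frame stacks a uniform `(d, h, w, c)` shape before the first frame is selected and resized to 256×256.
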
@@ -604,11 +613,9 @@ class LivePortraitInferencer:
         return psi_list
 
     def prepare_driving_video(self, face_images):
-        print("Prepare driving video...")
-        f_img_np = (face_images * 255).byte().numpy()
-
+        # print("Prepare driving video...")
         out_list = []
-        for f_img in
+        for f_img in face_images:
             i_d = self.prepare_src_image(f_img)
             d_info = self.pipeline.get_kp_info(i_d)
             out_list.append(d_info)
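
With the tensor conversion `(face_images * 255).byte().numpy()` removed, `prepare_driving_video` now iterates the frames as given, implying callers pass uint8 numpy arrays such as the output of `extract_frames`. A caller still holding normalized float frames would have to convert first, roughly:

```python
import numpy as np

# Hypothetical pre-conversion for float frames in [0, 1]; frames coming from
# extract_frames are already uint8 and need no conversion.
float_frames = [np.random.rand(256, 256, 3) for _ in range(4)]  # stand-in data
frames_u8 = [(f * 255).astype(np.uint8) for f in float_frames]
# driving_values = inferencer.prepare_driving_video(frames_u8)
```
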