jhj0517 committed
Commit · aa337d6 · 1 Parent(s): 6a40fe0

Fix type and bugs
modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -243,14 +243,13 @@ class LivePortraitInferencer:

     def create_video(self,
                      model_type: str = ModelType.HUMAN.value,
-                     retargeting_eyes:
-                     retargeting_mouth:
-                     tracking_src_vid: bool =
+                     retargeting_eyes: float = 0,
+                     retargeting_mouth: float = 0,
+                     tracking_src_vid: bool = False,
                      animate_without_vid: bool = False,
                      crop_factor: float = 1.5,
-
+                     src_image: Optional[str] = None,
                      driving_vid_path: Optional[str] = None,
-                     driving_images: Optional[List[np.ndarray]] = None,
                      progress: gr.Progress = gr.Progress()
                      ):
         if self.pipeline is None or model_type != self.model_type:
@@ -260,16 +259,17 @@ class LivePortraitInferencer:

         src_length = 1

-        if
-            src_length = len(
-            if id(
+        if src_image is not None:
+            src_length = len(src_image)
+            if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
                 self.crop_factor = crop_factor
-
+
                 if 1 < src_length:
-                    self.psi_list =
+                    self.psi_list = self.prepare_source(src_image, crop_factor, True, tracking_src_vid)
                 else:
-                    self.psi_list =
+                    self.psi_list = self.prepare_source(src_image, crop_factor)

+        driving_images, vid_sound = extract_frames(driving_vid_path), extract_sound(driving_vid_path)
         driving_length = 0
         if driving_images is not None:
             if id(driving_images) != id(self.driving_images):
@@ -281,6 +281,7 @@ class LivePortraitInferencer:

         if animate_without_vid:
             total_length = total_length
+            self.psi_list = [self.psi_list[0] for _ in range(total_length)]

         c_i_es = ExpressionSet()
         c_o_es = ExpressionSet()
@@ -292,6 +293,7 @@ class LivePortraitInferencer:

             if i < src_length:
                 psi = self.psi_list[i]
+
                 s_info = psi.x_s_info
                 s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))

@@ -299,7 +301,7 @@ class LivePortraitInferencer:

             if i < driving_length:
                 d_i_info = self.driving_values[i]
-                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])#.float().to(device="cuda:0")
+                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']]) #.float().to(device="cuda:0")

                 if d_0_es is None:
                     d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))
@@ -532,8 +534,15 @@ class LivePortraitInferencer:
         return new_img

     def prepare_src_image(self, img):
-
-
+        if isinstance(img, str):
+            img = image_path_to_array(img)
+
+        if len(img.shape) <= 3:
+            img = img[np.newaxis, ...]
+
+        d, h, w, c = img.shape
+        img = img[0]  # Select first dimension
+        input_shape = [256, 256]
         if h != input_shape[0] or w != input_shape[1]:
             if 256 < h: interpolation = cv2.INTER_AREA
             else: interpolation = cv2.INTER_LINEAR
@@ -604,11 +613,9 @@ class LivePortraitInferencer:
         return psi_list

     def prepare_driving_video(self, face_images):
-        print("Prepare driving video...")
-        f_img_np = (face_images * 255).byte().numpy()
-
+        # print("Prepare driving video...")
         out_list = []
-        for f_img in
+        for f_img in face_images:
             i_d = self.prepare_src_image(f_img)
             d_info = self.pipeline.get_kp_info(i_d)
             out_list.append(d_info)
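The hunk at new lines 262-270 gates the expensive prepare_source call on object identity rather than content. A minimal sketch of that caching pattern, with hypothetical names (SourceCache and prepare_fn are illustrative, not from the repository):

class SourceCache:
    def __init__(self):
        self.src_image = None
        self.prepared = None

    def prepare(self, src_image, prepare_fn):
        # id() compares object identity, not content: passing the same
        # object again skips the work, while a fresh object (e.g. a re-read
        # of the same file) triggers re-preparation, mirroring
        # `if id(src_image) != id(self.src_image)` in the patch.
        if self.prepared is None or id(src_image) != id(self.src_image):
            self.src_image = src_image
            self.prepared = prepare_fn(src_image)
        return self.prepared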
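The new prepare_src_image body normalizes path, still-image, and frame-stack inputs onto one 4D path before resizing. A self-contained sketch of those steps, assuming image_path_to_array can be approximated with cv2.imread plus a BGR-to-RGB swap (the repo helper's behavior is not shown in this diff), and omitting whatever the method does after the resize:

import cv2
import numpy as np

def prepare_src_image_sketch(img, input_shape=(256, 256)):
    # Path input: the repo uses image_path_to_array; cv2.imread plus a
    # BGR->RGB swap is an assumption standing in for it here.
    if isinstance(img, str):
        img = cv2.cvtColor(cv2.imread(img), cv2.COLOR_BGR2RGB)

    # Promote HxWxC to 1xHxWxC so stills and DxHxWxC frame stacks share a path.
    if len(img.shape) <= 3:
        img = img[np.newaxis, ...]

    d, h, w, c = img.shape
    img = img[0]  # select the first frame, as the patch does

    # Shrinking favors INTER_AREA, enlarging INTER_LINEAR, matching the hunk.
    if (h, w) != input_shape:
        interpolation = cv2.INTER_AREA if h > 256 else cv2.INTER_LINEAR
        img = cv2.resize(img, (input_shape[1], input_shape[0]), interpolation=interpolation)
    return img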
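The patch also derives driving_images and vid_sound inside create_video via extract_frames and extract_sound, whose definitions are not part of this diff. A hypothetical stand-in for the frame-extraction half, using only standard OpenCV calls:

import cv2

def extract_frames_sketch(vid_path):
    # Reads every frame of the driving video into a list of RGB arrays;
    # returns None when no video was supplied, matching the None check
    # on driving_images later in create_video.
    if vid_path is None:
        return None
    cap = cv2.VideoCapture(vid_path)
    frames = []
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()
    return frames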