jhj0517 committed
Commit aa337d6 · 1 Parent(s): 6a40fe0

Fix type and bugs

modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -243,14 +243,13 @@ class LivePortraitInferencer:
 
     def create_video(self,
                      model_type: str = ModelType.HUMAN.value,
-                     retargeting_eyes: bool = True,
-                     retargeting_mouth: bool = True,
-                     tracking_src_vid: bool = True,
+                     retargeting_eyes: float = 0,
+                     retargeting_mouth: float = 0,
+                     tracking_src_vid: bool = False,
                      animate_without_vid: bool = False,
                      crop_factor: float = 1.5,
-                     src_image_list: Optional[List[np.ndarray]] = None,
+                     src_image: Optional[str] = None,
                      driving_vid_path: Optional[str] = None,
-                     driving_images: Optional[List[np.ndarray]] = None,
                      progress: gr.Progress = gr.Progress()
                      ):
         if self.pipeline is None or model_type != self.model_type:
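The retargeting parameters change from on/off booleans to float amounts, tracking_src_vid now defaults to False, the source input becomes a file path rather than a pre-loaded frame list, and driving frames are no longer passed in directly (they are extracted from driving_vid_path below). A minimal calling sketch against the new signature, assuming a default-constructible LivePortraitInferencer; the file paths are hypothetical:

from modules.live_portrait.live_portrait_inferencer import LivePortraitInferencer

inferencer = LivePortraitInferencer()
inferencer.create_video(
    retargeting_eyes=0.0,                   # float amount now, not a bool toggle
    retargeting_mouth=0.0,
    tracking_src_vid=False,
    crop_factor=1.5,
    src_image="assets/source.png",          # hypothetical path; was a List[np.ndarray]
    driving_vid_path="assets/driving.mp4",  # hypothetical path
)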
@@ -260,16 +259,17 @@ class LivePortraitInferencer:
 
         src_length = 1
 
-        if src_image_list is not None:
-            src_length = len(src_image_list)
-            if id(src_image_list) != id(self.src_image_list) or self.crop_factor != crop_factor:
+        if src_image is not None:
+            src_length = len(src_image)
+            if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
                 self.crop_factor = crop_factor
-                self.src_image_list = src_image_list
+
                 if 1 < src_length:
-                    self.psi_list = [self.prepare_source(src, crop_factor, True, tracking_src_vid) for src in src_image_list]
+                    self.psi_list = self.prepare_source(src_image, crop_factor, True, tracking_src_vid)
                 else:
-                    self.psi_list = [self.prepare_source(src, crop_factor) for src in src_image_list]
+                    self.psi_list = self.prepare_source(src_image, crop_factor)
 
+        driving_images, vid_sound = extract_frames(driving_vid_path), extract_sound(driving_vid_path)
         driving_length = 0
         if driving_images is not None:
             if id(driving_images) != id(self.driving_images):
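The cache check compares inputs by id(), i.e. object identity rather than content, so source preparation is skipped only when the very same object is passed twice; note also that with src_image typed as str, len(src_image) measures the length of the path string. A short illustration of the identity-vs-equality distinction the check relies on:

a = ["frame0", "frame1"]
b = list(a)              # equal contents, distinct object

print(a == b)            # True: contents compare equal
print(id(a) == id(b))    # False: different objects, so an id()-based cache
                         # would treat b as new input and re-prepare it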
@@ -281,6 +281,7 @@ class LivePortraitInferencer:
 
         if animate_without_vid:
             total_length = total_length
+            self.psi_list = [self.psi_list[0] for _ in range(total_length)]
 
         c_i_es = ExpressionSet()
         c_o_es = ExpressionSet()
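The added line repeats the single prepared source so that psi_list[i] stays valid for every output frame when animating without a driving video (the preceding total_length = total_length is a no-op). The replication in isolation:

psi_list = ["psi_0"]
total_length = 4
psi_list = [psi_list[0] for _ in range(total_length)]
print(psi_list)   # ['psi_0', 'psi_0', 'psi_0', 'psi_0']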
@@ -292,6 +293,7 @@ class LivePortraitInferencer:
 
             if i < src_length:
                 psi = self.psi_list[i]
+
                 s_info = psi.x_s_info
                 s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))
 
@@ -299,7 +301,7 @@ class LivePortraitInferencer:
 
             if i < driving_length:
                 d_i_info = self.driving_values[i]
-                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])#.float().to(device="cuda:0")
+                d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']]) #.float().to(device="cuda:0")
 
                 if d_0_es is None:
                     d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))
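Only spacing before the commented-out dtype/device chain changes here. For reference, torch.Tensor(...) already produces a float32 CPU tensor, so the disabled .float() would be redundant; only the .to(device=...) part would matter, and only for GPU placement:

import torch

r = torch.Tensor([0.1, 0.2, 0.3])
print(r.dtype, r.device)      # torch.float32 cpu
# r = r.to(device="cuda:0")   # would be needed only to move it onto the GPU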
@@ -532,8 +534,15 @@ class LivePortraitInferencer:
         return new_img
 
     def prepare_src_image(self, img):
-        h, w = img.shape[:2]
-        input_shape = [256,256]
+        if isinstance(img, str):
+            img = image_path_to_array(img)
+
+        if len(img.shape) <= 3:
+            img = img[np.newaxis, ...]
+
+        d, h, w, c = img.shape
+        img = img[0]  # Select first dimension
+        input_shape = [256, 256]
         if h != input_shape[0] or w != input_shape[1]:
             if 256 < h: interpolation = cv2.INTER_AREA
             else: interpolation = cv2.INTER_LINEAR
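prepare_src_image now also accepts a file path (converted via image_path_to_array) and normalizes array rank before unpacking the shape: a plain HWC image gains a leading axis so d, h, w, c always unpacks, after which only the first frame is kept. The shape handling in isolation:

import numpy as np

img = np.zeros((512, 384, 3), dtype=np.uint8)  # a single H, W, C image
if len(img.shape) <= 3:
    img = img[np.newaxis, ...]                 # -> (1, 512, 384, 3)
d, h, w, c = img.shape
img = img[0]                                   # keep the first frame
print(d, h, w, c, img.shape)                   # 1 512 384 3 (512, 384, 3)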
@@ -604,11 +613,9 @@ class LivePortraitInferencer:
         return psi_list
 
     def prepare_driving_video(self, face_images):
-        print("Prepare driving video...")
-        f_img_np = (face_images * 255).byte().numpy()
-
+        # print("Prepare driving video...")
         out_list = []
-        for f_img in f_img_np:
+        for f_img in face_images:
             i_d = self.prepare_src_image(f_img)
             d_info = self.pipeline.get_kp_info(i_d)
             out_list.append(d_info)
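Driving frames are now consumed as-is, with per-frame normalization deferred to prepare_src_image above, instead of converting a [0, 1] float tensor to uint8 up front. What the removed conversion did, for reference:

import torch

face_images = torch.rand(3, 256, 256, 3)       # N, H, W, C floats in [0, 1]
f_img_np = (face_images * 255).byte().numpy()  # old path: uint8 frame batch
print(f_img_np.shape, f_img_np.dtype)          # (3, 256, 256, 3) uint8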
 