jhj0517 committed
Commit 45d5794 · 1 Parent(s): e5db983

raise error

modules/live_portrait/live_portrait_inferencer.py CHANGED
@@ -252,76 +252,79 @@ class LivePortraitInferencer:
             model_type=model_type
         )
 
-        vid_info = get_video_info(vid_input=driving_vid_path)
-
-        if src_image is not None:
-            if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
-                self.crop_factor = crop_factor
-                self.src_image = src_image
-
-                self.psi_list = [self.prepare_source(src_image, crop_factor)]
-
-        progress(0, desc="Extracting frames from the video..")
-        driving_images, vid_sound = extract_frames(driving_vid_path, os.path.join(self.output_dir, "temp", "video_frames")), extract_sound(driving_vid_path)
-
-        driving_length = 0
-        if driving_images is not None:
-            if id(driving_images) != id(self.driving_images):
-                self.driving_images = driving_images
-                self.driving_values = self.prepare_driving_video(driving_images)
-            driving_length = len(self.driving_values)
-
-        total_length = len(driving_images)
-
-        c_i_es = ExpressionSet()
-        c_o_es = ExpressionSet()
-        d_0_es = None
-
-        psi = None
-        with torch.autocast(device_type=self.device, enabled=(self.device == "cuda")):
-            for i in range(total_length):
-
-                if i == 0:
-                    psi = self.psi_list[i]
-                    s_info = psi.x_s_info
-                    s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))
-
-                new_es = ExpressionSet(es=s_es)
-
-                if i < driving_length:
-                    d_i_info = self.driving_values[i]
-                    d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])  # .float().to(device="cuda:0")
-
-                    if d_0_es is None:
-                        d_0_es = ExpressionSet(erst=(d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))
-
-                        self.retargeting(s_es.e, d_0_es.e, retargeting_eyes, (11, 13, 15, 16))
-                        self.retargeting(s_es.e, d_0_es.e, retargeting_mouth, (14, 17, 19, 20))
-
-                    new_es.e += d_i_info['exp'] - d_0_es.e
-                    new_es.r += d_i_r - d_0_es.r
-                    new_es.t += d_i_info['t'] - d_0_es.t
-
-                r_new = get_rotation_matrix(
-                    s_info['pitch'] + new_es.r[0], s_info['yaw'] + new_es.r[1], s_info['roll'] + new_es.r[2])
-                d_new = new_es.s * (new_es.e @ r_new) + new_es.t
-                d_new = self.pipeline.stitching(psi.x_s_user, d_new)
-                crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, d_new)
-                crop_out = self.pipeline.parse_output(crop_out['out'])[0]
-
-                crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
-                                                    cv2.INTER_LINEAR)
-                out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(
-                    np.uint8)
-
-                out_frame_path = get_auto_incremental_file_path(os.path.join(self.output_dir, "temp", "video_frames", "out"), "png")
-                save_image(out, out_frame_path)
-
-                progress(i/total_length, desc=f"Generating frames {i}/{total_length} ..")
-
-        video_path = create_video_from_frames(TEMP_VIDEO_OUT_FRAMES_DIR, frame_rate=vid_info.frame_rate, output_dir=os.path.join(self.output_dir, "videos"))
-
-        return video_path
+        try:
+            vid_info = get_video_info(vid_input=driving_vid_path)
+
+            if src_image is not None:
+                if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
+                    self.crop_factor = crop_factor
+                    self.src_image = src_image
+
+                    self.psi_list = [self.prepare_source(src_image, crop_factor)]
+
+            progress(0, desc="Extracting frames from the video..")
+            driving_images, vid_sound = extract_frames(driving_vid_path, os.path.join(self.output_dir, "temp", "video_frames")), extract_sound(driving_vid_path)
+
+            driving_length = 0
+            if driving_images is not None:
+                if id(driving_images) != id(self.driving_images):
+                    self.driving_images = driving_images
+                    self.driving_values = self.prepare_driving_video(driving_images)
+                driving_length = len(self.driving_values)
+
+            total_length = len(driving_images)
+
+            c_i_es = ExpressionSet()
+            c_o_es = ExpressionSet()
+            d_0_es = None
+
+            psi = None
+            with torch.autocast(device_type=self.device, enabled=(self.device == "cuda")):
+                for i in range(total_length):
+
+                    if i == 0:
+                        psi = self.psi_list[i]
+                        s_info = psi.x_s_info
+                        s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))
+
+                    new_es = ExpressionSet(es=s_es)
+
+                    if i < driving_length:
+                        d_i_info = self.driving_values[i]
+                        d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])  # .float().to(device="cuda:0")
+
+                        if d_0_es is None:
+                            d_0_es = ExpressionSet(erst=(d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))
+
+                            self.retargeting(s_es.e, d_0_es.e, retargeting_eyes, (11, 13, 15, 16))
+                            self.retargeting(s_es.e, d_0_es.e, retargeting_mouth, (14, 17, 19, 20))
+
+                        new_es.e += d_i_info['exp'] - d_0_es.e
+                        new_es.r += d_i_r - d_0_es.r
+                        new_es.t += d_i_info['t'] - d_0_es.t
+
+                    r_new = get_rotation_matrix(
+                        s_info['pitch'] + new_es.r[0], s_info['yaw'] + new_es.r[1], s_info['roll'] + new_es.r[2])
+                    d_new = new_es.s * (new_es.e @ r_new) + new_es.t
+                    d_new = self.pipeline.stitching(psi.x_s_user, d_new)
+                    crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, d_new)
+                    crop_out = self.pipeline.parse_output(crop_out['out'])[0]
+
+                    crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
+                                                        cv2.INTER_LINEAR)
+                    out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(
+                        np.uint8)
+
+                    out_frame_path = get_auto_incremental_file_path(os.path.join(self.output_dir, "temp", "video_frames", "out"), "png")
+                    save_image(out, out_frame_path)
+
+                    progress(i/total_length, desc=f"Generating frames {i}/{total_length} ..")
+
+            video_path = create_video_from_frames(TEMP_VIDEO_OUT_FRAMES_DIR, frame_rate=vid_info.frame_rate, output_dir=os.path.join(self.output_dir, "videos"))
+
+            return video_path
+        except Exception as e:
+            raise
 
     def download_if_no_models(self,
                               model_type: str = ModelType.HUMAN.value,
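
Note on the change: the added `except Exception as e: raise` is a bare re-raise, so any failure in the wrapped video-generation steps propagates unchanged, with its original traceback (the bound name `e` is currently unused). A minimal sketch of the pattern, with hypothetical names:

def generate(run_pipeline):
    try:
        return run_pipeline()  # stands in for the wrapped inference steps
    except Exception:
        # A bare `raise` re-raises the active exception with its original
        # traceback intact; binding it (`as e`) only matters if the handler
        # logs, wraps, or inspects the exception before re-raising.
        raise

def failing_pipeline():
    raise RuntimeError("demo failure")

try:
    generate(failing_pipeline)
except RuntimeError as err:
    print(f"propagated unchanged: {err}")  # propagated unchanged: demo failure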