YinuoGuo27 committed (verified)
Commit 2c1a720 · Parent(s): 62e754e

Update difpoint/inference.py

Files changed (1): difpoint/inference.py (+11 -29)
difpoint/inference.py CHANGED
@@ -7,6 +7,9 @@
 import glob
 
 import os
+os.environ['HYDRA_FULL_ERROR']='1'
+os.environ['CUDA_VISIBLE_DEVICES'] = '2'
+
 import os
 import time
 import shutil
@@ -54,7 +57,6 @@ import datetime
 import platform
 from omegaconf import OmegaConf
 from difpoint.src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
-import spaces
 
 FFMPEG = "ffmpeg"
 
@@ -160,19 +162,19 @@ class Inferencer(object):
         self.device = 'cuda'
         from difpoint.model import get_model
         self.point_diffusion = get_model()
-        ckpt = torch.load('./downloaded_repo/ckpts/KDTalker.pth', weights_only=False)
+        ckpt = torch.load('/home/yinuo/Gradio-UI_copy/difpoint/outputs/2024.08.26_dim_70_frame_64_vox1_selected_d6.5_c8.5/2024-08-26--16-52-34/checkpoint-500000.pth')
 
         self.point_diffusion.load_state_dict(ckpt['model'])
         print('model', self.point_diffusion.children())
         self.point_diffusion.eval()
         self.point_diffusion.to(self.device)
 
-        lm_croper_checkpoint = './downloaded_repo/ckpts/shape_predictor_68_face_landmarks.dat'
+        lm_croper_checkpoint = os.path.join('difpoint/dataset_process/ckpts/', 'shape_predictor_68_face_landmarks.dat')
         self.croper = Croper(lm_croper_checkpoint)
 
         self.norm_info = dict(np.load(r'difpoint/datasets/norm_info_d6.5_c8.5_vox1_train.npz'))
 
-        wav2lip_checkpoint = './downloaded_repo/ckpts/wav2lip.pth'
+        wav2lip_checkpoint = 'difpoint/dataset_process/ckpts/wav2lip.pth'
         self.wav2lip_model = AudioEncoder(wav2lip_checkpoint, 'cuda')
         self.wav2lip_model.cuda()
         self.wav2lip_model.eval()
@@ -270,7 +272,6 @@ class Inferencer(object):
         return combined_lip_ratio_tensor
 
     # 2024.06.26
-    @spaces.GPU
     @torch.no_grad()
     def generate_with_audio_img(self, upload_audio_path, tts_audio_path, audio_type, image_path, smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t, save_path='results'):
         print(audio_type)
@@ -305,12 +306,7 @@
 
         flag_lip_zero = self.inf_cfg.infer_params.flag_normalize_lip
 
-        if flag_lip_zero:
-            # let lip-open scalar to be 0 at first
-            c_d_lip_before_animation = [0.]
-
-            lip_delta_before_animation = self.live_portrait_pipeline.model_dict['stitching_lip_retarget'].predict(
-                concat_feat(x_s, combined_lip_ratio_tensor_before_animation))
+
 
         ######## process driving info ########
         kp_info = {}
@@ -442,30 +438,16 @@
 
            # Algorithm 1:
            if not self.inf_cfg.infer_params.flag_stitching and not self.inf_cfg.infer_params.flag_eye_retargeting and not self.inf_cfg.infer_params.flag_lip_retargeting:
-                # without stitching or retargeting
-                if flag_lip_zero:
-                    x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
-                else:
-                    pass
+                pass
            elif self.inf_cfg.infer_params.flag_stitching and not self.inf_cfg.infer_params.flag_eye_retargeting and not self.inf_cfg.infer_params.flag_lip_retargeting:
                # with stitching and without retargeting
-                if flag_lip_zero:
-                    x_d_i_new = self.live_portrait_pipeline.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(
-                        -1, x_s.shape[1], 3)
-                else:
-                    x_d_i_new = self.live_portrait_pipeline.stitching(x_s, x_d_i_new)
+                x_d_i_new = self.live_portrait_pipeline.stitching(x_s, x_d_i_new)
            else:
                eyes_delta, lip_delta = None, None
                if self.inf_cfg.infer_params.flag_eye_retargeting:
-                    c_d_eyes_i = template_dct['c_d_eyes_lst'][i]
-                    combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, c_s_eye)
-                    # ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
-                    eyes_delta = self.live_portrait_pipeline.retarget_eye(x_s, combined_eye_ratio_tensor)
+                    pass
                if self.inf_cfg.infer_params.flag_lip_retargeting:
-                    c_d_lip_i = template_dct['c_d_lip_lst'][i]
-                    combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, c_s_lip)
-                    # ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
-                    lip_delta = self.live_portrait_pipeline.retarget_lip(x_s, combined_lip_ratio_tensor)
+                    pass
 
            if self.inf_cfg.infer_params.flag_relative_motion:  # use x_s
                x_d_i_new = x_s + \
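
A note on the two environment variables added in the first hunk: CUDA_VISIBLE_DEVICES only takes effect if it is set before any CUDA-aware library initializes a device context, which is why the assignments sit next to the very first imports, and HYDRA_FULL_ERROR=1 makes Hydra print full stack traces instead of its condensed error summary. A minimal, repository-independent sketch of the pattern (the device index '2' simply mirrors the value hard-coded in this commit):

import os

# Must be set before torch (or any other CUDA user) initializes its CUDA
# context; once a context exists, changing the mask has no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
# Ask Hydra for complete tracebacks rather than the shortened summary.
os.environ['HYDRA_FULL_ERROR'] = '1'

import torch

# With GPU 2 as the only visible device, it appears to torch as cuda:0.
print(torch.cuda.device_count())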
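The checkpoint hunk keeps the usual two-step PyTorch pattern: torch.load returns a dictionary whose 'model' entry holds the state_dict, which is then passed to load_state_dict. A self-contained sketch of that layout with a toy module standing in for point_diffusion (the dimensions and file name are illustrative only, not taken from the repository):

import torch
import torch.nn as nn

# Toy stand-in for the diffusion model; real checkpoints use the same layout.
model = nn.Linear(70, 70)

# Write a checkpoint in the {'model': state_dict} format the loader expects.
torch.save({'model': model.state_dict()}, 'checkpoint-demo.pth')

# map_location='cpu' keeps the load independent of which GPU saved the file.
ckpt = torch.load('checkpoint-demo.pth', map_location='cpu')
model.load_state_dict(ckpt['model'])
model.eval()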