YinuoGuo27 commited on
Commit
5c7919e
·
verified ·
1 Parent(s): d0af744

Update difpoint/inference.py

Browse files
Files changed (1) hide show
  1. difpoint/inference.py +3 -3
difpoint/inference.py CHANGED
@@ -162,19 +162,19 @@ class Inferencer(object):
162
  self.device = 'cuda'
163
  from difpoint.model import get_model
164
  self.point_diffusion = get_model()
165
- ckpt = torch.load('/home/yinuo/Gradio-UI_copy/difpoint/outputs/2024.08.26_dim_70_frame_64_vox1_selected_d6.5_c8.5/2024-08-26--16-52-34/checkpoint-500000.pth')
166
 
167
  self.point_diffusion.load_state_dict(ckpt['model'])
168
  print('model', self.point_diffusion.children())
169
  self.point_diffusion.eval()
170
  self.point_diffusion.to(self.device)
171
 
172
- lm_croper_checkpoint = os.path.join('difpoint/dataset_process/ckpts/', 'shape_predictor_68_face_landmarks.dat')
173
  self.croper = Croper(lm_croper_checkpoint)
174
 
175
  self.norm_info = dict(np.load(r'difpoint/datasets/norm_info_d6.5_c8.5_vox1_train.npz'))
176
 
177
- wav2lip_checkpoint = 'difpoint/dataset_process/ckpts/wav2lip.pth'
178
  self.wav2lip_model = AudioEncoder(wav2lip_checkpoint, 'cuda')
179
  self.wav2lip_model.cuda()
180
  self.wav2lip_model.eval()
 
162
  self.device = 'cuda'
163
  from difpoint.model import get_model
164
  self.point_diffusion = get_model()
165
+ ckpt = torch.load('./downloaded_repo/ckpts/KDTalker.pth', weights_only=False)
166
 
167
  self.point_diffusion.load_state_dict(ckpt['model'])
168
  print('model', self.point_diffusion.children())
169
  self.point_diffusion.eval()
170
  self.point_diffusion.to(self.device)
171
 
172
+ lm_croper_checkpoint = './downloaded_repo/ckpts/shape_predictor_68_face_landmarks.dat'
173
  self.croper = Croper(lm_croper_checkpoint)
174
 
175
  self.norm_info = dict(np.load(r'difpoint/datasets/norm_info_d6.5_c8.5_vox1_train.npz'))
176
 
177
+ wav2lip_checkpoint = './downloaded_repo/ckpts/wav2lip.pth'
178
  self.wav2lip_model = AudioEncoder(wav2lip_checkpoint, 'cuda')
179
  self.wav2lip_model.cuda()
180
  self.wav2lip_model.eval()