YinuoGuo27 commited on
Commit
0210789
·
verified ·
1 Parent(s): b109544

Update difpoint/inference.py

Browse files
Files changed (1) hide show
  1. difpoint/inference.py +3 -10
difpoint/inference.py CHANGED
@@ -393,8 +393,6 @@ class Inferencer(object):
393
 
394
  for key in x_d_i_info:
395
  x_d_i_info[key] = torch.tensor(x_d_i_info[key]).cuda()
396
- for key in x_s_info:
397
- x_s_info[key] = torch.tensor(x_s_info[key]).cuda()
398
 
399
  R_d_i = x_d_i_info['R_d']
400
 
@@ -404,7 +402,7 @@ class Inferencer(object):
404
 
405
 
406
  if self.inf_cfg.flag_relative_motion:
407
- R_new = (R_d_i.cpu().numpy() @ R_d_0.permute(0, 2, 1).cpu().numpy()) @ R_s
408
  delta_new = x_s_info['exp'].reshape(1, 21, -1) + (x_d_i_info['exp'] - x_d_0_info['exp'])
409
  scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
410
  t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
@@ -413,15 +411,10 @@ class Inferencer(object):
413
  delta_new = x_d_i_info['exp']
414
  scale_new = x_s_info['scale']
415
  t_new = x_d_i_info['t']
416
-
417
  t_new[..., 2] = 0 # zero tz
418
- x_c_s = torch.tensor(x_c_s, dtype=torch.float32).cuda()
419
- R_new = torch.tensor(R_new, dtype=torch.float32).cuda()
420
- delta_new = torch.tensor(delta_new, dtype=torch.float32).cuda()
421
- t_new = torch.tensor(t_new, dtype=torch.float32).cuda()
422
- scale_new = torch.tensor(scale_new, dtype=torch.float32).cuda()
423
  x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
424
- x_d_i_new = x_d_i_new.cpu().numpy()
425
 
426
  # Algorithm 1:
427
  if not self.inf_cfg.flag_stitching and not self.inf_cfg.flag_eye_retargeting and not self.inf_cfg.flag_lip_retargeting:
 
393
 
394
  for key in x_d_i_info:
395
  x_d_i_info[key] = torch.tensor(x_d_i_info[key]).cuda()
 
 
396
 
397
  R_d_i = x_d_i_info['R_d']
398
 
 
402
 
403
 
404
  if self.inf_cfg.flag_relative_motion:
405
+ R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
406
  delta_new = x_s_info['exp'].reshape(1, 21, -1) + (x_d_i_info['exp'] - x_d_0_info['exp'])
407
  scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
408
  t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
 
411
  delta_new = x_d_i_info['exp']
412
  scale_new = x_s_info['scale']
413
  t_new = x_d_i_info['t']
 
414
  t_new[..., 2] = 0 # zero tz
415
+
 
 
 
 
416
  x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
417
+
418
 
419
  # Algorithm 1:
420
  if not self.inf_cfg.flag_stitching and not self.inf_cfg.flag_eye_retargeting and not self.inf_cfg.flag_lip_retargeting: