robinwitch committed
Commit f01632b · 1 Parent(s): 279199b
Files changed (1)
  1. app.py +397 -379
app.py CHANGED
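
The change moves the inference pipeline out of BaseTrainer: the _load_data, _g_test, _create_cuda_model, and _warp methods are deleted and recreated as module-level _warp_* functions that take the trainer's state (args, model, joint masks, normalization statistics, diffusion wrapper) as explicit arguments, so that @spaces.GPU(duration=149) decorates a plain function rather than a bound method. As a minimal sketch of that decorator pattern, assuming the Hugging Face Spaces ZeroGPU "spaces" package (the run_inference function below is illustrative, not part of this repo):

    import spaces
    import torch

    # Illustrative only: under ZeroGPU a GPU is attached only while the
    # decorated function runs, so the GPU work is gathered into one
    # module-level function that receives everything it needs as arguments.
    @spaces.GPU(duration=149)  # hold the GPU for at most 149 seconds
    def run_inference(model, batch):
        model = model.cuda().eval()           # move weights inside the GPU window
        with torch.no_grad():
            return model(batch.cuda()).cpu()  # hand results back on CPU
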
@@ -242,383 +242,6 @@ class BaseTrainer(object):
             original_shape_t[i, selected_indices] = filtered_t[i]
         return original_shape_t
 
-    def _load_data(self, dict_data):
-        tar_pose_raw = dict_data["pose"]
-        tar_pose = tar_pose_raw[:, :, :165].to(self.rank)
-        tar_contact = tar_pose_raw[:, :, 165:169].to(self.rank)
-        tar_trans = dict_data["trans"].to(self.rank)
-        tar_trans_v = dict_data["trans_v"].to(self.rank)
-        tar_exps = dict_data["facial"].to(self.rank)
-        in_audio = dict_data["audio"].to(self.rank)
-        in_word = dict_data["word"].to(self.rank)
-        tar_beta = dict_data["beta"].to(self.rank)
-        tar_id = dict_data["id"].to(self.rank).long()
-        bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
-
-        tar_pose_jaw = tar_pose[:, :, 66:69]
-        tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
-        tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6)
-        tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
-
-        tar_pose_hands = tar_pose[:, :, 25*3:55*3]
-        tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
-        tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
-
-        tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
-        tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
-        tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
-
-        tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
-        tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
-        tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
-
-        tar_pose_lower = tar_pose_leg
-
-
-        tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2)
-
-
-        if self.args.pose_norm:
-            tar_pose_upper = (tar_pose_upper - self.mean_upper) / self.std_upper
-            tar_pose_hands = (tar_pose_hands - self.mean_hands) / self.std_hands
-            tar_pose_lower = (tar_pose_lower - self.mean_lower) / self.std_lower
-
-        if self.use_trans:
-            tar_trans_v = (tar_trans_v - self.trans_mean)/self.trans_std
-            tar_pose_lower = torch.cat([tar_pose_lower,tar_trans_v], dim=-1)
-
-        latent_face_top = None#self.vq_model_face.map2latent(tar_pose_face) # bs*n/4
-        latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper)
-        latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands)
-        latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower)
-
-        latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2)/self.args.vqvae_latent_scale
-
-
-        tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
-        tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6)
-        latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
-        style_feature = None
-        if self.args.use_motionclip:
-            motionclip_feat = tar_pose_6d[...,:22*6]
-            batch = {}
-            bs,seq,feat = motionclip_feat.shape
-            batch['x']=motionclip_feat.permute(0,2,1).contiguous()
-            batch['y']=torch.zeros(bs).int().cuda()
-            batch['mask']=torch.ones([bs,seq]).bool().cuda()
-            style_feature = self.motionclip.encoder(batch)['mu'].detach().float()
-
-
-
-        # print(tar_index_value_upper_top.shape, index_in.shape)
-        return {
-            "tar_pose_jaw": tar_pose_jaw,
-            "tar_pose_face": tar_pose_face,
-            "tar_pose_upper": tar_pose_upper,
-            "tar_pose_lower": tar_pose_lower,
-            "tar_pose_hands": tar_pose_hands,
-            'tar_pose_leg': tar_pose_leg,
-            "in_audio": in_audio,
-            "in_word": in_word,
-            "tar_trans": tar_trans,
-            "tar_exps": tar_exps,
-            "tar_beta": tar_beta,
-            "tar_pose": tar_pose,
-            "tar4dis": tar4dis,
-            "latent_face_top": latent_face_top,
-            "latent_upper_top": latent_upper_top,
-            "latent_hands_top": latent_hands_top,
-            "latent_lower_top": latent_lower_top,
-            "latent_in": latent_in,
-            "tar_id": tar_id,
-            "latent_all": latent_all,
-            "tar_pose_6d": tar_pose_6d,
-            "tar_contact": tar_contact,
-            "style_feature":style_feature,
-        }
-
-    def _g_test(self, loaded_data):
-        sample_fn = self.diffusion.p_sample_loop
-        if self.args.use_ddim:
-            sample_fn = self.diffusion.ddim_sample_loop
-        mode = 'test'
-        bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints
-        tar_pose = loaded_data["tar_pose"]
-        tar_beta = loaded_data["tar_beta"]
-        tar_exps = loaded_data["tar_exps"]
-        tar_contact = loaded_data["tar_contact"]
-        tar_trans = loaded_data["tar_trans"]
-        in_word = loaded_data["in_word"]
-        in_audio = loaded_data["in_audio"]
-        in_x0 = loaded_data['latent_in']
-        in_seed = loaded_data['latent_in']
-
-        remain = n%8
-        if remain != 0:
-            tar_pose = tar_pose[:, :-remain, :]
-            tar_beta = tar_beta[:, :-remain, :]
-            tar_trans = tar_trans[:, :-remain, :]
-            in_word = in_word[:, :-remain]
-            tar_exps = tar_exps[:, :-remain, :]
-            tar_contact = tar_contact[:, :-remain, :]
-            in_x0 = in_x0[:, :in_x0.shape[1]-(remain//self.args.vqvae_squeeze_scale), :]
-            in_seed = in_seed[:, :in_x0.shape[1]-(remain//self.args.vqvae_squeeze_scale), :]
-            n = n - remain
-
-        tar_pose_jaw = tar_pose[:, :, 66:69]
-        tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
-        tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6)
-        tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
-
-        tar_pose_hands = tar_pose[:, :, 25*3:55*3]
-        tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
-        tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
-
-        tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
-        tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
-        tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
-
-        tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
-        tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
-        tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
-        tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2)
-
-        tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
-        tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6)
-        latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
-
-        rec_all_face = []
-        rec_all_upper = []
-        rec_all_lower = []
-        rec_all_hands = []
-        vqvae_squeeze_scale = self.args.vqvae_squeeze_scale
-        roundt = (n - self.args.pre_frames * vqvae_squeeze_scale) // (self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale)
-        remain = (n - self.args.pre_frames * vqvae_squeeze_scale) % (self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale)
-        round_l = self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale
-
-
-        for i in range(0, roundt):
-            in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames * vqvae_squeeze_scale]
-
-            in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*self.args.pre_frames * vqvae_squeeze_scale]
-            in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames]
-            in_seed_tmp = in_seed[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+self.args.pre_frames]
-            in_x0_tmp = in_x0[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+self.args.pre_frames]
-            mask_val = torch.ones(bs, self.args.pose_length, self.args.pose_dims+3+4).float().cuda()
-            mask_val[:, :self.args.pre_frames, :] = 0.0
-            if i == 0:
-                in_seed_tmp = in_seed_tmp[:, :self.args.pre_frames, :]
-            else:
-                in_seed_tmp = last_sample[:, -self.args.pre_frames:, :]
-
-            cond_ = {'y':{}}
-            cond_['y']['audio'] = in_audio_tmp
-            cond_['y']['word'] = in_word_tmp
-            cond_['y']['id'] = in_id_tmp
-            cond_['y']['seed'] =in_seed_tmp
-            cond_['y']['mask'] = (torch.zeros([self.args.batch_size, 1, 1, self.args.pose_length]) < 1).cuda()
-
-
-
-            cond_['y']['style_feature'] = torch.zeros([bs, 512]).cuda()
-
-            shape_ = (bs, 1536, 1, 32)
-            sample = sample_fn(
-                self.model,
-                shape_,
-                clip_denoised=False,
-                model_kwargs=cond_,
-                skip_timesteps=0,
-                init_image=None,
-                progress=True,
-                dump_steps=None,
-                noise=None,
-                const_noise=False,
-            )
-            sample = sample.squeeze().permute(1,0).unsqueeze(0)
-
-            last_sample = sample.clone()
-
-            rec_latent_upper = sample[...,:512]
-            rec_latent_hands = sample[...,512:1024]
-            rec_latent_lower = sample[...,1024:1536]
-
-
-
-            if i == 0:
-                rec_all_upper.append(rec_latent_upper)
-                rec_all_hands.append(rec_latent_hands)
-                rec_all_lower.append(rec_latent_lower)
-            else:
-                rec_all_upper.append(rec_latent_upper[:, self.args.pre_frames:])
-                rec_all_hands.append(rec_latent_hands[:, self.args.pre_frames:])
-                rec_all_lower.append(rec_latent_lower[:, self.args.pre_frames:])
-
-        rec_all_upper = torch.cat(rec_all_upper, dim=1) * self.vqvae_latent_scale
-        rec_all_hands = torch.cat(rec_all_hands, dim=1) * self.vqvae_latent_scale
-        rec_all_lower = torch.cat(rec_all_lower, dim=1) * self.vqvae_latent_scale
-
-        rec_upper = self.vq_model_upper.latent2origin(rec_all_upper)[0]
-        rec_hands = self.vq_model_hands.latent2origin(rec_all_hands)[0]
-        rec_lower = self.vq_model_lower.latent2origin(rec_all_lower)[0]
-
-
-        if self.use_trans:
-            rec_trans_v = rec_lower[...,-3:]
-            rec_trans_v = rec_trans_v * self.trans_std + self.trans_mean
-            rec_trans = torch.zeros_like(rec_trans_v)
-            rec_trans = torch.cumsum(rec_trans_v, dim=-2)
-            rec_trans[...,1]=rec_trans_v[...,1]
-            rec_lower = rec_lower[...,:-3]
-
-        if self.args.pose_norm:
-            rec_upper = rec_upper * self.std_upper + self.mean_upper
-            rec_hands = rec_hands * self.std_hands + self.mean_hands
-            rec_lower = rec_lower * self.std_lower + self.mean_lower
-
-
-
-
-        n = n - remain
-        tar_pose = tar_pose[:, :n, :]
-        tar_exps = tar_exps[:, :n, :]
-        tar_trans = tar_trans[:, :n, :]
-        tar_beta = tar_beta[:, :n, :]
-
-
-        rec_exps = tar_exps
-        #rec_pose_jaw = rec_face[:, :, :6]
-        rec_pose_legs = rec_lower[:, :, :54]
-        bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]
-        rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
-        rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)#
-        rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3)
-        rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n)
-        rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
-        rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
-        rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6)
-        rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3)
-        rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs*n)
-        rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
-        rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
-        rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3)
-        rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n)
-        rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
-        rec_pose[:, 66:69] = tar_pose.reshape(bs*n, 55*3)[:, 66:69]
-
-        rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3))
-        rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
-        tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3))
-        tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
-
-        return {
-            'rec_pose': rec_pose,
-            'rec_trans': rec_trans,
-            'tar_pose': tar_pose,
-            'tar_exps': tar_exps,
-            'tar_beta': tar_beta,
-            'tar_trans': tar_trans,
-            'rec_exps': rec_exps,
-        }
-
-
-    def _create_cuda_model(self):
-        args = self.args
-        other_tools.load_checkpoints(self.model, args.test_ckpt, args.g_name)
-        args.num_quantizers = 6
-        args.shared_codebook = False
-        args.quantize_dropout_prob = 0.2
-        args.mu = 0.99
-
-        args.nb_code = 512
-        args.code_dim = 512
-        args.code_dim = 512
-        args.down_t = 2
-        args.stride_t = 2
-        args.width = 512
-        args.depth = 3
-        args.dilation_growth_rate = 3
-        args.vq_act = "relu"
-        args.vq_norm = None
-
-        dim_pose = 78
-        args.body_part = "upper"
-        self.vq_model_upper = RVQVAE(args,
-            dim_pose,
-            args.nb_code,
-            args.code_dim,
-            args.code_dim,
-            args.down_t,
-            args.stride_t,
-            args.width,
-            args.depth,
-            args.dilation_growth_rate,
-            args.vq_act,
-            args.vq_norm)
-
-        dim_pose = 180
-        args.body_part = "hands"
-        self.vq_model_hands = RVQVAE(args,
-            dim_pose,
-            args.nb_code,
-            args.code_dim,
-            args.code_dim,
-            args.down_t,
-            args.stride_t,
-            args.width,
-            args.depth,
-            args.dilation_growth_rate,
-            args.vq_act,
-            args.vq_norm)
-
-        dim_pose = 54
-        if args.use_trans:
-            dim_pose = 57
-            self.args.vqvae_lower_path = self.args.vqvae_lower_trans_path
-        args.body_part = "lower"
-        self.vq_model_lower = RVQVAE(args,
-            dim_pose,
-            args.nb_code,
-            args.code_dim,
-            args.code_dim,
-            args.down_t,
-            args.stride_t,
-            args.width,
-            args.depth,
-            args.dilation_growth_rate,
-            args.vq_act,
-            args.vq_norm)
-
-        self.vq_model_upper.load_state_dict(torch.load(self.args.vqvae_upper_path)['net'])
-        self.vq_model_hands.load_state_dict(torch.load(self.args.vqvae_hands_path)['net'])
-        self.vq_model_lower.load_state_dict(torch.load(self.args.vqvae_lower_path)['net'])
-
-        self.vqvae_latent_scale = self.args.vqvae_latent_scale
-
-        self.vq_model_upper.eval().to(self.rank)
-        self.vq_model_hands.eval().to(self.rank)
-        self.vq_model_lower.eval().to(self.rank)
-
-        self.model = self.model.cuda()
-        self.model.eval()
-
-        self.mean_upper = torch.from_numpy(self.mean_upper).cuda()
-        self.mean_hands = torch.from_numpy(self.mean_hands).cuda()
-        self.mean_lower = torch.from_numpy(self.mean_lower).cuda()
-        self.std_upper = torch.from_numpy(self.std_upper).cuda()
-        self.std_hands = torch.from_numpy(self.std_hands).cuda()
-        self.std_lower = torch.from_numpy(self.std_lower).cuda()
-        self.trans_mean = torch.from_numpy(self.trans_mean).cuda()
-        self.trans_std = torch.from_numpy(self.trans_std).cuda()
-
-    @spaces.GPU(duration=149)
-    def _warp(self, batch_data):
-        self._create_cuda_model()
-
-
-        loaded_data = self._load_data(batch_data)
-        net_out = self._g_test(loaded_data)
-        return net_out
 
     def test_demo(self, epoch):
         '''
@@ -644,7 +267,7 @@ class BaseTrainer(object):
         for its, batch_data in enumerate(self.test_loader):
             # loaded_data = self._load_data(batch_data)
             # net_out = self._g_test(loaded_data)
-            net_out = self._warp(batch_data)
+            net_out = _warp(self.args,self.model, batch_data,self.joints,self.joint_mask_upper,self.joint_mask_hands,self.joint_mask_lower,self.use_trans,self.diffusion)
             tar_pose = net_out['tar_pose']
             rec_pose = net_out['rec_pose']
             tar_exps = net_out['tar_exps']
@@ -708,7 +331,402 @@ class BaseTrainer(object):
         end_time = time.time() - start_time
         logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")
         return result
-
+
+
+@spaces.GPU(duration=149)
+def _warp(args,model, batch_data,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,use_trans,diffusion):
+    args,model,vq_model_upper,vq_model_hands,vq_model_lower,mean_upper,mean_hands,mean_lower,std_upper,std_hands,std_lower,trans_mean,trans_std,vqvae_latent_scale=_warp_create_cuda_model(args,model)
+
+
+    loaded_data = _warp_load_data(
+        batch_data,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,args,use_trans,mean_upper,mean_hands,mean_lower,std_upper,std_hands,std_lower,trans_mean,trans_std,vq_model_upper,vq_model_hands,vq_model_lower
+    )
+    net_out = _warp_g_test(loaded_data,diffusion,args,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,model,vqvae_latent_scale,vq_model_upper,vq_model_hands,vq_model_lower,use_trans,trans_std,trans_mean,std_upper,std_hands,std_lower,mean_upper,mean_hands,mean_lower)
+    return net_out
+
+def _warp_inverse_selection_tensor(filtered_t, selection_array, n):
+    selection_array = torch.from_numpy(selection_array).cuda()
+    original_shape_t = torch.zeros((n, 165)).cuda()
+    selected_indices = torch.where(selection_array == 1)[0]
+    for i in range(n):
+        original_shape_t[i, selected_indices] = filtered_t[i]
+    return original_shape_t
+
+def _warp_g_test(loaded_data,diffusion,args,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,model,vqvae_latent_scale,vq_model_upper,vq_model_hands,vq_model_lower,use_trans,trans_std,trans_mean,std_upper,std_hands,std_lower,mean_upper,mean_hands,mean_lower):
+    sample_fn = diffusion.p_sample_loop
+    if args.use_ddim:
+        sample_fn = diffusion.ddim_sample_loop
+    mode = 'test'
+    bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], joints
+    tar_pose = loaded_data["tar_pose"]
+    tar_beta = loaded_data["tar_beta"]
+    tar_exps = loaded_data["tar_exps"]
+    tar_contact = loaded_data["tar_contact"]
+    tar_trans = loaded_data["tar_trans"]
+    in_word = loaded_data["in_word"]
+    in_audio = loaded_data["in_audio"]
+    in_x0 = loaded_data['latent_in']
+    in_seed = loaded_data['latent_in']
+
+    remain = n%8
+    if remain != 0:
+        tar_pose = tar_pose[:, :-remain, :]
+        tar_beta = tar_beta[:, :-remain, :]
+        tar_trans = tar_trans[:, :-remain, :]
+        in_word = in_word[:, :-remain]
+        tar_exps = tar_exps[:, :-remain, :]
+        tar_contact = tar_contact[:, :-remain, :]
+        in_x0 = in_x0[:, :in_x0.shape[1]-(remain//args.vqvae_squeeze_scale), :]
+        in_seed = in_seed[:, :in_x0.shape[1]-(remain//args.vqvae_squeeze_scale), :]
+        n = n - remain
+
+    tar_pose_jaw = tar_pose[:, :, 66:69]
+    tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
+    tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6)
+    tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
+
+    tar_pose_hands = tar_pose[:, :, 25*3:55*3]
+    tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
+    tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
+
+    tar_pose_upper = tar_pose[:, :, joint_mask_upper.astype(bool)]
+    tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
+    tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
+
+    tar_pose_leg = tar_pose[:, :, joint_mask_lower.astype(bool)]
+    tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
+    tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
+    tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2)
+
+    tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
+    tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6)
+    latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
+
+    rec_all_face = []
+    rec_all_upper = []
+    rec_all_lower = []
+    rec_all_hands = []
+    vqvae_squeeze_scale = args.vqvae_squeeze_scale
+    roundt = (n - args.pre_frames * vqvae_squeeze_scale) // (args.pose_length - args.pre_frames * vqvae_squeeze_scale)
+    remain = (n - args.pre_frames * vqvae_squeeze_scale) % (args.pose_length - args.pre_frames * vqvae_squeeze_scale)
+    round_l = args.pose_length - args.pre_frames * vqvae_squeeze_scale
+
+
+    for i in range(0, roundt):
+        in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+args.pre_frames * vqvae_squeeze_scale]
+
+        in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*args.pre_frames * vqvae_squeeze_scale]
+        in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+args.pre_frames]
+        in_seed_tmp = in_seed[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+args.pre_frames]
+        in_x0_tmp = in_x0[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+args.pre_frames]
+        mask_val = torch.ones(bs, args.pose_length, args.pose_dims+3+4).float().cuda()
+        mask_val[:, :args.pre_frames, :] = 0.0
+        if i == 0:
+            in_seed_tmp = in_seed_tmp[:, :args.pre_frames, :]
+        else:
+            in_seed_tmp = last_sample[:, -args.pre_frames:, :]
+
+        cond_ = {'y':{}}
+        cond_['y']['audio'] = in_audio_tmp
+        cond_['y']['word'] = in_word_tmp
+        cond_['y']['id'] = in_id_tmp
+        cond_['y']['seed'] =in_seed_tmp
+        cond_['y']['mask'] = (torch.zeros([args.batch_size, 1, 1, args.pose_length]) < 1).cuda()
+
+
+
+        cond_['y']['style_feature'] = torch.zeros([bs, 512]).cuda()
+
+        shape_ = (bs, 1536, 1, 32)
+        sample = sample_fn(
+            model,
+            shape_,
+            clip_denoised=False,
+            model_kwargs=cond_,
+            skip_timesteps=0,
+            init_image=None,
+            progress=True,
+            dump_steps=None,
+            noise=None,
+            const_noise=False,
+        )
+        sample = sample.squeeze().permute(1,0).unsqueeze(0)
+
+        last_sample = sample.clone()
+
+        rec_latent_upper = sample[...,:512]
+        rec_latent_hands = sample[...,512:1024]
+        rec_latent_lower = sample[...,1024:1536]
+
+
+
+        if i == 0:
+            rec_all_upper.append(rec_latent_upper)
+            rec_all_hands.append(rec_latent_hands)
+            rec_all_lower.append(rec_latent_lower)
+        else:
+            rec_all_upper.append(rec_latent_upper[:, args.pre_frames:])
+            rec_all_hands.append(rec_latent_hands[:, args.pre_frames:])
+            rec_all_lower.append(rec_latent_lower[:, args.pre_frames:])
+
+    rec_all_upper = torch.cat(rec_all_upper, dim=1) * vqvae_latent_scale
+    rec_all_hands = torch.cat(rec_all_hands, dim=1) * vqvae_latent_scale
+    rec_all_lower = torch.cat(rec_all_lower, dim=1) * vqvae_latent_scale
+
+    rec_upper = vq_model_upper.latent2origin(rec_all_upper)[0]
+    rec_hands = vq_model_hands.latent2origin(rec_all_hands)[0]
+    rec_lower = vq_model_lower.latent2origin(rec_all_lower)[0]
+
+
+    if use_trans:
+        rec_trans_v = rec_lower[...,-3:]
+        rec_trans_v = rec_trans_v * trans_std + trans_mean
+        rec_trans = torch.zeros_like(rec_trans_v)
+        rec_trans = torch.cumsum(rec_trans_v, dim=-2)
+        rec_trans[...,1]=rec_trans_v[...,1]
+        rec_lower = rec_lower[...,:-3]
+
+    if args.pose_norm:
+        rec_upper = rec_upper * std_upper + mean_upper
+        rec_hands = rec_hands * std_hands + mean_hands
+        rec_lower = rec_lower * std_lower + mean_lower
+
+
+
+
+    n = n - remain
+    tar_pose = tar_pose[:, :n, :]
+    tar_exps = tar_exps[:, :n, :]
+    tar_trans = tar_trans[:, :n, :]
+    tar_beta = tar_beta[:, :n, :]
+
+
+    rec_exps = tar_exps
+    #rec_pose_jaw = rec_face[:, :, :6]
+    rec_pose_legs = rec_lower[:, :, :54]
+    bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]
+    rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
+    rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)#
+    rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3)
+    rec_pose_upper_recover = _warp_inverse_selection_tensor(rec_pose_upper, joint_mask_upper, bs*n)
+    rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
+    rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
+    rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6)
+    rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3)
+    rec_pose_lower_recover = _warp_inverse_selection_tensor(rec_pose_lower, joint_mask_lower, bs*n)
+    rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
+    rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
+    rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3)
+    rec_pose_hands_recover = _warp_inverse_selection_tensor(rec_pose_hands, joint_mask_hands, bs*n)
+    rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
+    rec_pose[:, 66:69] = tar_pose.reshape(bs*n, 55*3)[:, 66:69]
+
+    rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3))
+    rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
+    tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3))
+    tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
+
+    return {
+        'rec_pose': rec_pose,
+        'rec_trans': rec_trans,
+        'tar_pose': tar_pose,
+        'tar_exps': tar_exps,
+        'tar_beta': tar_beta,
+        'tar_trans': tar_trans,
+        'rec_exps': rec_exps,
+    }
+
+
+
+def _warp_load_data(dict_data,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,args,use_trans,mean_upper,mean_hands,mean_lower,std_upper,std_hands,std_lower,trans_mean,trans_std,vq_model_upper,vq_model_hands,vq_model_lower):
+    tar_pose_raw = dict_data["pose"]
+    tar_pose = tar_pose_raw[:, :, :165].cuda()
+    tar_contact = tar_pose_raw[:, :, 165:169].cuda()
+    tar_trans = dict_data["trans"].cuda()
+    tar_trans_v = dict_data["trans_v"].cuda()
+    tar_exps = dict_data["facial"].cuda()
+    in_audio = dict_data["audio"].cuda()
+    in_word = dict_data["word"].cuda()
+    tar_beta = dict_data["beta"].cuda()
+    tar_id = dict_data["id"].cuda().long()
+    bs, n, j = tar_pose.shape[0], tar_pose.shape[1], joints
+
+    tar_pose_jaw = tar_pose[:, :, 66:69]
+    tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
+    tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6)
+    tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
+
+    tar_pose_hands = tar_pose[:, :, 25*3:55*3]
+    tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
+    tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
+
+    tar_pose_upper = tar_pose[:, :, joint_mask_upper.astype(bool)]
+    tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
+    tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
+
+    tar_pose_leg = tar_pose[:, :, joint_mask_lower.astype(bool)]
+    tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
+    tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
+
+    tar_pose_lower = tar_pose_leg
+
+
+    tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2)
+
+
+    if args.pose_norm:
+        tar_pose_upper = (tar_pose_upper - mean_upper) / std_upper
+        tar_pose_hands = (tar_pose_hands - mean_hands) / std_hands
+        tar_pose_lower = (tar_pose_lower - mean_lower) / std_lower
+
+    if use_trans:
+        tar_trans_v = (tar_trans_v - trans_mean)/trans_std
+        tar_pose_lower = torch.cat([tar_pose_lower,tar_trans_v], dim=-1)
+
+    latent_face_top = None#self.vq_model_face.map2latent(tar_pose_face) # bs*n/4
+    latent_upper_top = vq_model_upper.map2latent(tar_pose_upper)
+    latent_hands_top = vq_model_hands.map2latent(tar_pose_hands)
+    latent_lower_top = vq_model_lower.map2latent(tar_pose_lower)
+
+    latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2)/args.vqvae_latent_scale
+
+
+    tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
+    tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6)
+    latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
+    style_feature = None
+    if args.use_motionclip:
+        motionclip_feat = tar_pose_6d[...,:22*6]
+        batch = {}
+        bs,seq,feat = motionclip_feat.shape
+        batch['x']=motionclip_feat.permute(0,2,1).contiguous()
+        batch['y']=torch.zeros(bs).int().cuda()
+        batch['mask']=torch.ones([bs,seq]).bool().cuda()
+        style_feature = motionclip.encoder(batch)['mu'].detach().float()
+
+
+
+    # print(tar_index_value_upper_top.shape, index_in.shape)
+    return {
+        "tar_pose_jaw": tar_pose_jaw,
+        "tar_pose_face": tar_pose_face,
+        "tar_pose_upper": tar_pose_upper,
+        "tar_pose_lower": tar_pose_lower,
+        "tar_pose_hands": tar_pose_hands,
+        'tar_pose_leg': tar_pose_leg,
+        "in_audio": in_audio,
+        "in_word": in_word,
+        "tar_trans": tar_trans,
+        "tar_exps": tar_exps,
+        "tar_beta": tar_beta,
+        "tar_pose": tar_pose,
+        "tar4dis": tar4dis,
+        "latent_face_top": latent_face_top,
+        "latent_upper_top": latent_upper_top,
+        "latent_hands_top": latent_hands_top,
+        "latent_lower_top": latent_lower_top,
+        "latent_in": latent_in,
+        "tar_id": tar_id,
+        "latent_all": latent_all,
+        "tar_pose_6d": tar_pose_6d,
+        "tar_contact": tar_contact,
+        "style_feature":style_feature,
+    }
+
+
+def _warp_create_cuda_model(args,model):
+    args = args
+    other_tools.load_checkpoints(model, args.test_ckpt, args.g_name)
+    args.num_quantizers = 6
+    args.shared_codebook = False
+    args.quantize_dropout_prob = 0.2
+    args.mu = 0.99
+
+    args.nb_code = 512
+    args.code_dim = 512
+    args.code_dim = 512
+    args.down_t = 2
+    args.stride_t = 2
+    args.width = 512
+    args.depth = 3
+    args.dilation_growth_rate = 3
+    args.vq_act = "relu"
+    args.vq_norm = None
+
+    dim_pose = 78
+    args.body_part = "upper"
+    vq_model_upper = RVQVAE(args,
+        dim_pose,
+        args.nb_code,
+        args.code_dim,
+        args.code_dim,
+        args.down_t,
+        args.stride_t,
+        args.width,
+        args.depth,
+        args.dilation_growth_rate,
+        args.vq_act,
+        args.vq_norm)
+
+    dim_pose = 180
+    args.body_part = "hands"
+    vq_model_hands = RVQVAE(args,
+        dim_pose,
+        args.nb_code,
+        args.code_dim,
+        args.code_dim,
+        args.down_t,
+        args.stride_t,
+        args.width,
+        args.depth,
+        args.dilation_growth_rate,
+        args.vq_act,
+        args.vq_norm)
+
+    dim_pose = 54
+    if args.use_trans:
+        dim_pose = 57
+        args.vqvae_lower_path = args.vqvae_lower_trans_path
+    args.body_part = "lower"
+    vq_model_lower = RVQVAE(args,
+        dim_pose,
+        args.nb_code,
+        args.code_dim,
+        args.code_dim,
+        args.down_t,
+        args.stride_t,
+        args.width,
+        args.depth,
+        args.dilation_growth_rate,
+        args.vq_act,
+        args.vq_norm)
+
+    vq_model_upper.load_state_dict(torch.load(args.vqvae_upper_path)['net'])
+    vq_model_hands.load_state_dict(torch.load(args.vqvae_hands_path)['net'])
+    vq_model_lower.load_state_dict(torch.load(args.vqvae_lower_path)['net'])
+
+    vqvae_latent_scale = args.vqvae_latent_scale
+
+    vq_model_upper.eval().cuda()
+    vq_model_hands.eval().cuda()
+    vq_model_lower.eval().cuda()
+
+    model = model.cuda()
+    model.eval()
+
+    mean_upper = torch.from_numpy(mean_upper).cuda()
+    mean_hands = torch.from_numpy(mean_hands).cuda()
+    mean_lower = torch.from_numpy(mean_lower).cuda()
+    std_upper = torch.from_numpy(std_upper).cuda()
+    std_hands = torch.from_numpy(std_hands).cuda()
+    std_lower = torch.from_numpy(std_lower).cuda()
+    trans_mean = torch.from_numpy(trans_mean).cuda()
+    trans_std = torch.from_numpy(trans_std).cuda()
+
+    return args,model,vq_model_upper,vq_model_hands,vq_model_lower,mean_upper,mean_hands,mean_lower,std_upper,std_hands,std_lower,trans_mean,trans_std,vqvae_latent_scale
+
+
+
 @logger.catch
 def syntalker(audio_path,sample_stratege):
     args = config.parse_args()
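
Note: rc above appears to be the repository's rotation-conversions helper; the calls used throughout this diff (axis_angle_to_matrix, matrix_to_rotation_6d, rotation_6d_to_matrix, matrix_to_axis_angle) match the names and signatures in pytorch3d.transforms. A small illustrative round trip of the pose representation the code shuttles between, assuming pytorch3d is available (not a snippet from app.py):

    import torch
    from pytorch3d.transforms import (axis_angle_to_matrix, matrix_to_axis_angle,
                                      matrix_to_rotation_6d, rotation_6d_to_matrix)

    aa = torch.randn(2, 8, 55, 3)         # (batch, frames, joints, axis-angle)
    mat = axis_angle_to_matrix(aa)        # (2, 8, 55, 3, 3) rotation matrices
    d6 = matrix_to_rotation_6d(mat)       # (2, 8, 55, 6) continuous 6D rotations
    flat = d6.reshape(2, 8, 55 * 6)       # flattened per-frame feature, as in app.py
    back = matrix_to_axis_angle(
        rotation_6d_to_matrix(flat.reshape(2, 8, 55, 6)))  # recover axis-angle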