Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
f01632b
1
Parent(s):
279199b
add
Browse files
app.py
CHANGED
@@ -242,383 +242,6 @@ class BaseTrainer(object):
|
|
242 |
original_shape_t[i, selected_indices] = filtered_t[i]
|
243 |
return original_shape_t
|
244 |
|
245 |
-
def _load_data(self, dict_data):
|
246 |
-
tar_pose_raw = dict_data["pose"]
|
247 |
-
tar_pose = tar_pose_raw[:, :, :165].to(self.rank)
|
248 |
-
tar_contact = tar_pose_raw[:, :, 165:169].to(self.rank)
|
249 |
-
tar_trans = dict_data["trans"].to(self.rank)
|
250 |
-
tar_trans_v = dict_data["trans_v"].to(self.rank)
|
251 |
-
tar_exps = dict_data["facial"].to(self.rank)
|
252 |
-
in_audio = dict_data["audio"].to(self.rank)
|
253 |
-
in_word = dict_data["word"].to(self.rank)
|
254 |
-
tar_beta = dict_data["beta"].to(self.rank)
|
255 |
-
tar_id = dict_data["id"].to(self.rank).long()
|
256 |
-
bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints
|
257 |
-
|
258 |
-
tar_pose_jaw = tar_pose[:, :, 66:69]
|
259 |
-
tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
|
260 |
-
tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6)
|
261 |
-
tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
|
262 |
-
|
263 |
-
tar_pose_hands = tar_pose[:, :, 25*3:55*3]
|
264 |
-
tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
|
265 |
-
tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
|
266 |
-
|
267 |
-
tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
|
268 |
-
tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
|
269 |
-
tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
|
270 |
-
|
271 |
-
tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
|
272 |
-
tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
|
273 |
-
tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
|
274 |
-
|
275 |
-
tar_pose_lower = tar_pose_leg
|
276 |
-
|
277 |
-
|
278 |
-
tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2)
|
279 |
-
|
280 |
-
|
281 |
-
if self.args.pose_norm:
|
282 |
-
tar_pose_upper = (tar_pose_upper - self.mean_upper) / self.std_upper
|
283 |
-
tar_pose_hands = (tar_pose_hands - self.mean_hands) / self.std_hands
|
284 |
-
tar_pose_lower = (tar_pose_lower - self.mean_lower) / self.std_lower
|
285 |
-
|
286 |
-
if self.use_trans:
|
287 |
-
tar_trans_v = (tar_trans_v - self.trans_mean)/self.trans_std
|
288 |
-
tar_pose_lower = torch.cat([tar_pose_lower,tar_trans_v], dim=-1)
|
289 |
-
|
290 |
-
latent_face_top = None#self.vq_model_face.map2latent(tar_pose_face) # bs*n/4
|
291 |
-
latent_upper_top = self.vq_model_upper.map2latent(tar_pose_upper)
|
292 |
-
latent_hands_top = self.vq_model_hands.map2latent(tar_pose_hands)
|
293 |
-
latent_lower_top = self.vq_model_lower.map2latent(tar_pose_lower)
|
294 |
-
|
295 |
-
latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2)/self.args.vqvae_latent_scale
|
296 |
-
|
297 |
-
|
298 |
-
tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
|
299 |
-
tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6)
|
300 |
-
latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
|
301 |
-
style_feature = None
|
302 |
-
if self.args.use_motionclip:
|
303 |
-
motionclip_feat = tar_pose_6d[...,:22*6]
|
304 |
-
batch = {}
|
305 |
-
bs,seq,feat = motionclip_feat.shape
|
306 |
-
batch['x']=motionclip_feat.permute(0,2,1).contiguous()
|
307 |
-
batch['y']=torch.zeros(bs).int().cuda()
|
308 |
-
batch['mask']=torch.ones([bs,seq]).bool().cuda()
|
309 |
-
style_feature = self.motionclip.encoder(batch)['mu'].detach().float()
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
# print(tar_index_value_upper_top.shape, index_in.shape)
|
314 |
-
return {
|
315 |
-
"tar_pose_jaw": tar_pose_jaw,
|
316 |
-
"tar_pose_face": tar_pose_face,
|
317 |
-
"tar_pose_upper": tar_pose_upper,
|
318 |
-
"tar_pose_lower": tar_pose_lower,
|
319 |
-
"tar_pose_hands": tar_pose_hands,
|
320 |
-
'tar_pose_leg': tar_pose_leg,
|
321 |
-
"in_audio": in_audio,
|
322 |
-
"in_word": in_word,
|
323 |
-
"tar_trans": tar_trans,
|
324 |
-
"tar_exps": tar_exps,
|
325 |
-
"tar_beta": tar_beta,
|
326 |
-
"tar_pose": tar_pose,
|
327 |
-
"tar4dis": tar4dis,
|
328 |
-
"latent_face_top": latent_face_top,
|
329 |
-
"latent_upper_top": latent_upper_top,
|
330 |
-
"latent_hands_top": latent_hands_top,
|
331 |
-
"latent_lower_top": latent_lower_top,
|
332 |
-
"latent_in": latent_in,
|
333 |
-
"tar_id": tar_id,
|
334 |
-
"latent_all": latent_all,
|
335 |
-
"tar_pose_6d": tar_pose_6d,
|
336 |
-
"tar_contact": tar_contact,
|
337 |
-
"style_feature":style_feature,
|
338 |
-
}
|
339 |
-
|
340 |
-
def _g_test(self, loaded_data):
|
341 |
-
sample_fn = self.diffusion.p_sample_loop
|
342 |
-
if self.args.use_ddim:
|
343 |
-
sample_fn = self.diffusion.ddim_sample_loop
|
344 |
-
mode = 'test'
|
345 |
-
bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], self.joints
|
346 |
-
tar_pose = loaded_data["tar_pose"]
|
347 |
-
tar_beta = loaded_data["tar_beta"]
|
348 |
-
tar_exps = loaded_data["tar_exps"]
|
349 |
-
tar_contact = loaded_data["tar_contact"]
|
350 |
-
tar_trans = loaded_data["tar_trans"]
|
351 |
-
in_word = loaded_data["in_word"]
|
352 |
-
in_audio = loaded_data["in_audio"]
|
353 |
-
in_x0 = loaded_data['latent_in']
|
354 |
-
in_seed = loaded_data['latent_in']
|
355 |
-
|
356 |
-
remain = n%8
|
357 |
-
if remain != 0:
|
358 |
-
tar_pose = tar_pose[:, :-remain, :]
|
359 |
-
tar_beta = tar_beta[:, :-remain, :]
|
360 |
-
tar_trans = tar_trans[:, :-remain, :]
|
361 |
-
in_word = in_word[:, :-remain]
|
362 |
-
tar_exps = tar_exps[:, :-remain, :]
|
363 |
-
tar_contact = tar_contact[:, :-remain, :]
|
364 |
-
in_x0 = in_x0[:, :in_x0.shape[1]-(remain//self.args.vqvae_squeeze_scale), :]
|
365 |
-
in_seed = in_seed[:, :in_x0.shape[1]-(remain//self.args.vqvae_squeeze_scale), :]
|
366 |
-
n = n - remain
|
367 |
-
|
368 |
-
tar_pose_jaw = tar_pose[:, :, 66:69]
|
369 |
-
tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
|
370 |
-
tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6)
|
371 |
-
tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
|
372 |
-
|
373 |
-
tar_pose_hands = tar_pose[:, :, 25*3:55*3]
|
374 |
-
tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
|
375 |
-
tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
|
376 |
-
|
377 |
-
tar_pose_upper = tar_pose[:, :, self.joint_mask_upper.astype(bool)]
|
378 |
-
tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
|
379 |
-
tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
|
380 |
-
|
381 |
-
tar_pose_leg = tar_pose[:, :, self.joint_mask_lower.astype(bool)]
|
382 |
-
tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
|
383 |
-
tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
|
384 |
-
tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2)
|
385 |
-
|
386 |
-
tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
|
387 |
-
tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6)
|
388 |
-
latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
|
389 |
-
|
390 |
-
rec_all_face = []
|
391 |
-
rec_all_upper = []
|
392 |
-
rec_all_lower = []
|
393 |
-
rec_all_hands = []
|
394 |
-
vqvae_squeeze_scale = self.args.vqvae_squeeze_scale
|
395 |
-
roundt = (n - self.args.pre_frames * vqvae_squeeze_scale) // (self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale)
|
396 |
-
remain = (n - self.args.pre_frames * vqvae_squeeze_scale) % (self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale)
|
397 |
-
round_l = self.args.pose_length - self.args.pre_frames * vqvae_squeeze_scale
|
398 |
-
|
399 |
-
|
400 |
-
for i in range(0, roundt):
|
401 |
-
in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames * vqvae_squeeze_scale]
|
402 |
-
|
403 |
-
in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*self.args.pre_frames * vqvae_squeeze_scale]
|
404 |
-
in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+self.args.pre_frames]
|
405 |
-
in_seed_tmp = in_seed[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+self.args.pre_frames]
|
406 |
-
in_x0_tmp = in_x0[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+self.args.pre_frames]
|
407 |
-
mask_val = torch.ones(bs, self.args.pose_length, self.args.pose_dims+3+4).float().cuda()
|
408 |
-
mask_val[:, :self.args.pre_frames, :] = 0.0
|
409 |
-
if i == 0:
|
410 |
-
in_seed_tmp = in_seed_tmp[:, :self.args.pre_frames, :]
|
411 |
-
else:
|
412 |
-
in_seed_tmp = last_sample[:, -self.args.pre_frames:, :]
|
413 |
-
|
414 |
-
cond_ = {'y':{}}
|
415 |
-
cond_['y']['audio'] = in_audio_tmp
|
416 |
-
cond_['y']['word'] = in_word_tmp
|
417 |
-
cond_['y']['id'] = in_id_tmp
|
418 |
-
cond_['y']['seed'] =in_seed_tmp
|
419 |
-
cond_['y']['mask'] = (torch.zeros([self.args.batch_size, 1, 1, self.args.pose_length]) < 1).cuda()
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
cond_['y']['style_feature'] = torch.zeros([bs, 512]).cuda()
|
424 |
-
|
425 |
-
shape_ = (bs, 1536, 1, 32)
|
426 |
-
sample = sample_fn(
|
427 |
-
self.model,
|
428 |
-
shape_,
|
429 |
-
clip_denoised=False,
|
430 |
-
model_kwargs=cond_,
|
431 |
-
skip_timesteps=0,
|
432 |
-
init_image=None,
|
433 |
-
progress=True,
|
434 |
-
dump_steps=None,
|
435 |
-
noise=None,
|
436 |
-
const_noise=False,
|
437 |
-
)
|
438 |
-
sample = sample.squeeze().permute(1,0).unsqueeze(0)
|
439 |
-
|
440 |
-
last_sample = sample.clone()
|
441 |
-
|
442 |
-
rec_latent_upper = sample[...,:512]
|
443 |
-
rec_latent_hands = sample[...,512:1024]
|
444 |
-
rec_latent_lower = sample[...,1024:1536]
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
if i == 0:
|
449 |
-
rec_all_upper.append(rec_latent_upper)
|
450 |
-
rec_all_hands.append(rec_latent_hands)
|
451 |
-
rec_all_lower.append(rec_latent_lower)
|
452 |
-
else:
|
453 |
-
rec_all_upper.append(rec_latent_upper[:, self.args.pre_frames:])
|
454 |
-
rec_all_hands.append(rec_latent_hands[:, self.args.pre_frames:])
|
455 |
-
rec_all_lower.append(rec_latent_lower[:, self.args.pre_frames:])
|
456 |
-
|
457 |
-
rec_all_upper = torch.cat(rec_all_upper, dim=1) * self.vqvae_latent_scale
|
458 |
-
rec_all_hands = torch.cat(rec_all_hands, dim=1) * self.vqvae_latent_scale
|
459 |
-
rec_all_lower = torch.cat(rec_all_lower, dim=1) * self.vqvae_latent_scale
|
460 |
-
|
461 |
-
rec_upper = self.vq_model_upper.latent2origin(rec_all_upper)[0]
|
462 |
-
rec_hands = self.vq_model_hands.latent2origin(rec_all_hands)[0]
|
463 |
-
rec_lower = self.vq_model_lower.latent2origin(rec_all_lower)[0]
|
464 |
-
|
465 |
-
|
466 |
-
if self.use_trans:
|
467 |
-
rec_trans_v = rec_lower[...,-3:]
|
468 |
-
rec_trans_v = rec_trans_v * self.trans_std + self.trans_mean
|
469 |
-
rec_trans = torch.zeros_like(rec_trans_v)
|
470 |
-
rec_trans = torch.cumsum(rec_trans_v, dim=-2)
|
471 |
-
rec_trans[...,1]=rec_trans_v[...,1]
|
472 |
-
rec_lower = rec_lower[...,:-3]
|
473 |
-
|
474 |
-
if self.args.pose_norm:
|
475 |
-
rec_upper = rec_upper * self.std_upper + self.mean_upper
|
476 |
-
rec_hands = rec_hands * self.std_hands + self.mean_hands
|
477 |
-
rec_lower = rec_lower * self.std_lower + self.mean_lower
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
n = n - remain
|
483 |
-
tar_pose = tar_pose[:, :n, :]
|
484 |
-
tar_exps = tar_exps[:, :n, :]
|
485 |
-
tar_trans = tar_trans[:, :n, :]
|
486 |
-
tar_beta = tar_beta[:, :n, :]
|
487 |
-
|
488 |
-
|
489 |
-
rec_exps = tar_exps
|
490 |
-
#rec_pose_jaw = rec_face[:, :, :6]
|
491 |
-
rec_pose_legs = rec_lower[:, :, :54]
|
492 |
-
bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]
|
493 |
-
rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
|
494 |
-
rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)#
|
495 |
-
rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3)
|
496 |
-
rec_pose_upper_recover = self.inverse_selection_tensor(rec_pose_upper, self.joint_mask_upper, bs*n)
|
497 |
-
rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
|
498 |
-
rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
|
499 |
-
rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6)
|
500 |
-
rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3)
|
501 |
-
rec_pose_lower_recover = self.inverse_selection_tensor(rec_pose_lower, self.joint_mask_lower, bs*n)
|
502 |
-
rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
|
503 |
-
rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
|
504 |
-
rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3)
|
505 |
-
rec_pose_hands_recover = self.inverse_selection_tensor(rec_pose_hands, self.joint_mask_hands, bs*n)
|
506 |
-
rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
|
507 |
-
rec_pose[:, 66:69] = tar_pose.reshape(bs*n, 55*3)[:, 66:69]
|
508 |
-
|
509 |
-
rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3))
|
510 |
-
rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
|
511 |
-
tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3))
|
512 |
-
tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
|
513 |
-
|
514 |
-
return {
|
515 |
-
'rec_pose': rec_pose,
|
516 |
-
'rec_trans': rec_trans,
|
517 |
-
'tar_pose': tar_pose,
|
518 |
-
'tar_exps': tar_exps,
|
519 |
-
'tar_beta': tar_beta,
|
520 |
-
'tar_trans': tar_trans,
|
521 |
-
'rec_exps': rec_exps,
|
522 |
-
}
|
523 |
-
|
524 |
-
|
525 |
-
def _create_cuda_model(self):
|
526 |
-
args = self.args
|
527 |
-
other_tools.load_checkpoints(self.model, args.test_ckpt, args.g_name)
|
528 |
-
args.num_quantizers = 6
|
529 |
-
args.shared_codebook = False
|
530 |
-
args.quantize_dropout_prob = 0.2
|
531 |
-
args.mu = 0.99
|
532 |
-
|
533 |
-
args.nb_code = 512
|
534 |
-
args.code_dim = 512
|
535 |
-
args.code_dim = 512
|
536 |
-
args.down_t = 2
|
537 |
-
args.stride_t = 2
|
538 |
-
args.width = 512
|
539 |
-
args.depth = 3
|
540 |
-
args.dilation_growth_rate = 3
|
541 |
-
args.vq_act = "relu"
|
542 |
-
args.vq_norm = None
|
543 |
-
|
544 |
-
dim_pose = 78
|
545 |
-
args.body_part = "upper"
|
546 |
-
self.vq_model_upper = RVQVAE(args,
|
547 |
-
dim_pose,
|
548 |
-
args.nb_code,
|
549 |
-
args.code_dim,
|
550 |
-
args.code_dim,
|
551 |
-
args.down_t,
|
552 |
-
args.stride_t,
|
553 |
-
args.width,
|
554 |
-
args.depth,
|
555 |
-
args.dilation_growth_rate,
|
556 |
-
args.vq_act,
|
557 |
-
args.vq_norm)
|
558 |
-
|
559 |
-
dim_pose = 180
|
560 |
-
args.body_part = "hands"
|
561 |
-
self.vq_model_hands = RVQVAE(args,
|
562 |
-
dim_pose,
|
563 |
-
args.nb_code,
|
564 |
-
args.code_dim,
|
565 |
-
args.code_dim,
|
566 |
-
args.down_t,
|
567 |
-
args.stride_t,
|
568 |
-
args.width,
|
569 |
-
args.depth,
|
570 |
-
args.dilation_growth_rate,
|
571 |
-
args.vq_act,
|
572 |
-
args.vq_norm)
|
573 |
-
|
574 |
-
dim_pose = 54
|
575 |
-
if args.use_trans:
|
576 |
-
dim_pose = 57
|
577 |
-
self.args.vqvae_lower_path = self.args.vqvae_lower_trans_path
|
578 |
-
args.body_part = "lower"
|
579 |
-
self.vq_model_lower = RVQVAE(args,
|
580 |
-
dim_pose,
|
581 |
-
args.nb_code,
|
582 |
-
args.code_dim,
|
583 |
-
args.code_dim,
|
584 |
-
args.down_t,
|
585 |
-
args.stride_t,
|
586 |
-
args.width,
|
587 |
-
args.depth,
|
588 |
-
args.dilation_growth_rate,
|
589 |
-
args.vq_act,
|
590 |
-
args.vq_norm)
|
591 |
-
|
592 |
-
self.vq_model_upper.load_state_dict(torch.load(self.args.vqvae_upper_path)['net'])
|
593 |
-
self.vq_model_hands.load_state_dict(torch.load(self.args.vqvae_hands_path)['net'])
|
594 |
-
self.vq_model_lower.load_state_dict(torch.load(self.args.vqvae_lower_path)['net'])
|
595 |
-
|
596 |
-
self.vqvae_latent_scale = self.args.vqvae_latent_scale
|
597 |
-
|
598 |
-
self.vq_model_upper.eval().to(self.rank)
|
599 |
-
self.vq_model_hands.eval().to(self.rank)
|
600 |
-
self.vq_model_lower.eval().to(self.rank)
|
601 |
-
|
602 |
-
self.model = self.model.cuda()
|
603 |
-
self.model.eval()
|
604 |
-
|
605 |
-
self.mean_upper = torch.from_numpy(self.mean_upper).cuda()
|
606 |
-
self.mean_hands = torch.from_numpy(self.mean_hands).cuda()
|
607 |
-
self.mean_lower = torch.from_numpy(self.mean_lower).cuda()
|
608 |
-
self.std_upper = torch.from_numpy(self.std_upper).cuda()
|
609 |
-
self.std_hands = torch.from_numpy(self.std_hands).cuda()
|
610 |
-
self.std_lower = torch.from_numpy(self.std_lower).cuda()
|
611 |
-
self.trans_mean = torch.from_numpy(self.trans_mean).cuda()
|
612 |
-
self.trans_std = torch.from_numpy(self.trans_std).cuda()
|
613 |
-
|
614 |
-
@spaces.GPU(duration=149)
|
615 |
-
def _warp(self, batch_data):
|
616 |
-
self._create_cuda_model()
|
617 |
-
|
618 |
-
|
619 |
-
loaded_data = self._load_data(batch_data)
|
620 |
-
net_out = self._g_test(loaded_data)
|
621 |
-
return net_out
|
622 |
|
623 |
def test_demo(self, epoch):
|
624 |
'''
|
@@ -644,7 +267,7 @@ class BaseTrainer(object):
|
|
644 |
for its, batch_data in enumerate(self.test_loader):
|
645 |
# loaded_data = self._load_data(batch_data)
|
646 |
# net_out = self._g_test(loaded_data)
|
647 |
-
net_out = self.
|
648 |
tar_pose = net_out['tar_pose']
|
649 |
rec_pose = net_out['rec_pose']
|
650 |
tar_exps = net_out['tar_exps']
|
@@ -708,7 +331,402 @@ class BaseTrainer(object):
|
|
708 |
end_time = time.time() - start_time
|
709 |
logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")
|
710 |
return result
|
711 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
712 |
@logger.catch
|
713 |
def syntalker(audio_path,sample_stratege):
|
714 |
args = config.parse_args()
|
|
|
242 |
original_shape_t[i, selected_indices] = filtered_t[i]
|
243 |
return original_shape_t
|
244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
245 |
|
246 |
def test_demo(self, epoch):
|
247 |
'''
|
|
|
267 |
for its, batch_data in enumerate(self.test_loader):
|
268 |
# loaded_data = self._load_data(batch_data)
|
269 |
# net_out = self._g_test(loaded_data)
|
270 |
+
net_out = _warp(self.args,self.model, batch_data,self.joints,self.joint_mask_upper,self.joint_mask_hands,self.joint_mask_lower,self.use_trans,self.diffusion)
|
271 |
tar_pose = net_out['tar_pose']
|
272 |
rec_pose = net_out['rec_pose']
|
273 |
tar_exps = net_out['tar_exps']
|
|
|
331 |
end_time = time.time() - start_time
|
332 |
logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")
|
333 |
return result
|
334 |
+
|
335 |
+
|
336 |
+
@spaces.GPU(duration=149)
|
337 |
+
def _warp(args,model, batch_data,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,use_trans,diffusion):
|
338 |
+
args,model,vq_model_upper,vq_model_hands,vq_model_lower,mean_upper,mean_hands,mean_lower,std_upper,std_hands,std_lower,trans_mean,trans_std,vqvae_latent_scale=_warp_create_cuda_model(args,model)
|
339 |
+
|
340 |
+
|
341 |
+
loaded_data = _warp_load_data(
|
342 |
+
batch_data,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,args,use_trans,mean_upper,mean_hands,mean_lower,std_upper,std_hands,std_lower,trans_mean,trans_std,vq_model_upper,vq_model_hands,vq_model_lower
|
343 |
+
)
|
344 |
+
net_out = _warp_g_test(loaded_data,diffusion,args,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,model,vqvae_latent_scale,vq_model_upper,vq_model_hands,vq_model_lower,use_trans,trans_std,trans_mean,std_upper,std_hands,std_lower,mean_upper,mean_hands,mean_lower)
|
345 |
+
return net_out
|
346 |
+
|
347 |
+
def _warp_inverse_selection_tensor(filtered_t, selection_array, n):
|
348 |
+
selection_array = torch.from_numpy(selection_array).cuda()
|
349 |
+
original_shape_t = torch.zeros((n, 165)).cuda()
|
350 |
+
selected_indices = torch.where(selection_array == 1)[0]
|
351 |
+
for i in range(n):
|
352 |
+
original_shape_t[i, selected_indices] = filtered_t[i]
|
353 |
+
return original_shape_t
|
354 |
+
|
355 |
+
def _warp_g_test(loaded_data,diffusion,args,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,model,vqvae_latent_scale,vq_model_upper,vq_model_hands,vq_model_lower,use_trans,trans_std,trans_mean,std_upper,std_hands,std_lower,mean_upper,mean_hands,mean_lower):
|
356 |
+
sample_fn = diffusion.p_sample_loop
|
357 |
+
if args.use_ddim:
|
358 |
+
sample_fn = diffusion.ddim_sample_loop
|
359 |
+
mode = 'test'
|
360 |
+
bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], joints
|
361 |
+
tar_pose = loaded_data["tar_pose"]
|
362 |
+
tar_beta = loaded_data["tar_beta"]
|
363 |
+
tar_exps = loaded_data["tar_exps"]
|
364 |
+
tar_contact = loaded_data["tar_contact"]
|
365 |
+
tar_trans = loaded_data["tar_trans"]
|
366 |
+
in_word = loaded_data["in_word"]
|
367 |
+
in_audio = loaded_data["in_audio"]
|
368 |
+
in_x0 = loaded_data['latent_in']
|
369 |
+
in_seed = loaded_data['latent_in']
|
370 |
+
|
371 |
+
remain = n%8
|
372 |
+
if remain != 0:
|
373 |
+
tar_pose = tar_pose[:, :-remain, :]
|
374 |
+
tar_beta = tar_beta[:, :-remain, :]
|
375 |
+
tar_trans = tar_trans[:, :-remain, :]
|
376 |
+
in_word = in_word[:, :-remain]
|
377 |
+
tar_exps = tar_exps[:, :-remain, :]
|
378 |
+
tar_contact = tar_contact[:, :-remain, :]
|
379 |
+
in_x0 = in_x0[:, :in_x0.shape[1]-(remain//args.vqvae_squeeze_scale), :]
|
380 |
+
in_seed = in_seed[:, :in_x0.shape[1]-(remain//args.vqvae_squeeze_scale), :]
|
381 |
+
n = n - remain
|
382 |
+
|
383 |
+
tar_pose_jaw = tar_pose[:, :, 66:69]
|
384 |
+
tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
|
385 |
+
tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6)
|
386 |
+
tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)
|
387 |
+
|
388 |
+
tar_pose_hands = tar_pose[:, :, 25*3:55*3]
|
389 |
+
tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
|
390 |
+
tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6)
|
391 |
+
|
392 |
+
tar_pose_upper = tar_pose[:, :, joint_mask_upper.astype(bool)]
|
393 |
+
tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
|
394 |
+
tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6)
|
395 |
+
|
396 |
+
tar_pose_leg = tar_pose[:, :, joint_mask_lower.astype(bool)]
|
397 |
+
tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
|
398 |
+
tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6)
|
399 |
+
tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2)
|
400 |
+
|
401 |
+
tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
|
402 |
+
tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6)
|
403 |
+
latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)
|
404 |
+
|
405 |
+
rec_all_face = []
|
406 |
+
rec_all_upper = []
|
407 |
+
rec_all_lower = []
|
408 |
+
rec_all_hands = []
|
409 |
+
vqvae_squeeze_scale = args.vqvae_squeeze_scale
|
410 |
+
roundt = (n - args.pre_frames * vqvae_squeeze_scale) // (args.pose_length - args.pre_frames * vqvae_squeeze_scale)
|
411 |
+
remain = (n - args.pre_frames * vqvae_squeeze_scale) % (args.pose_length - args.pre_frames * vqvae_squeeze_scale)
|
412 |
+
round_l = args.pose_length - args.pre_frames * vqvae_squeeze_scale
|
413 |
+
|
414 |
+
|
415 |
+
for i in range(0, roundt):
|
416 |
+
in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+args.pre_frames * vqvae_squeeze_scale]
|
417 |
+
|
418 |
+
in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*args.pre_frames * vqvae_squeeze_scale]
|
419 |
+
in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+args.pre_frames]
|
420 |
+
in_seed_tmp = in_seed[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+args.pre_frames]
|
421 |
+
in_x0_tmp = in_x0[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+args.pre_frames]
|
422 |
+
mask_val = torch.ones(bs, args.pose_length, args.pose_dims+3+4).float().cuda()
|
423 |
+
mask_val[:, :args.pre_frames, :] = 0.0
|
424 |
+
if i == 0:
|
425 |
+
in_seed_tmp = in_seed_tmp[:, :args.pre_frames, :]
|
426 |
+
else:
|
427 |
+
in_seed_tmp = last_sample[:, -args.pre_frames:, :]
|
428 |
+
|
429 |
+
cond_ = {'y':{}}
|
430 |
+
cond_['y']['audio'] = in_audio_tmp
|
431 |
+
cond_['y']['word'] = in_word_tmp
|
432 |
+
cond_['y']['id'] = in_id_tmp
|
433 |
+
cond_['y']['seed'] =in_seed_tmp
|
434 |
+
cond_['y']['mask'] = (torch.zeros([args.batch_size, 1, 1, args.pose_length]) < 1).cuda()
|
435 |
+
|
436 |
+
|
437 |
+
|
438 |
+
cond_['y']['style_feature'] = torch.zeros([bs, 512]).cuda()
|
439 |
+
|
440 |
+
shape_ = (bs, 1536, 1, 32)
|
441 |
+
sample = sample_fn(
|
442 |
+
model,
|
443 |
+
shape_,
|
444 |
+
clip_denoised=False,
|
445 |
+
model_kwargs=cond_,
|
446 |
+
skip_timesteps=0,
|
447 |
+
init_image=None,
|
448 |
+
progress=True,
|
449 |
+
dump_steps=None,
|
450 |
+
noise=None,
|
451 |
+
const_noise=False,
|
452 |
+
)
|
453 |
+
sample = sample.squeeze().permute(1,0).unsqueeze(0)
|
454 |
+
|
455 |
+
last_sample = sample.clone()
|
456 |
+
|
457 |
+
rec_latent_upper = sample[...,:512]
|
458 |
+
rec_latent_hands = sample[...,512:1024]
|
459 |
+
rec_latent_lower = sample[...,1024:1536]
|
460 |
+
|
461 |
+
|
462 |
+
|
463 |
+
if i == 0:
|
464 |
+
rec_all_upper.append(rec_latent_upper)
|
465 |
+
rec_all_hands.append(rec_latent_hands)
|
466 |
+
rec_all_lower.append(rec_latent_lower)
|
467 |
+
else:
|
468 |
+
rec_all_upper.append(rec_latent_upper[:, args.pre_frames:])
|
469 |
+
rec_all_hands.append(rec_latent_hands[:, args.pre_frames:])
|
470 |
+
rec_all_lower.append(rec_latent_lower[:, args.pre_frames:])
|
471 |
+
|
472 |
+
rec_all_upper = torch.cat(rec_all_upper, dim=1) * vqvae_latent_scale
|
473 |
+
rec_all_hands = torch.cat(rec_all_hands, dim=1) * vqvae_latent_scale
|
474 |
+
rec_all_lower = torch.cat(rec_all_lower, dim=1) * vqvae_latent_scale
|
475 |
+
|
476 |
+
rec_upper = vq_model_upper.latent2origin(rec_all_upper)[0]
|
477 |
+
rec_hands = vq_model_hands.latent2origin(rec_all_hands)[0]
|
478 |
+
rec_lower = vq_model_lower.latent2origin(rec_all_lower)[0]
|
479 |
+
|
480 |
+
|
481 |
+
if use_trans:
|
482 |
+
rec_trans_v = rec_lower[...,-3:]
|
483 |
+
rec_trans_v = rec_trans_v * trans_std + trans_mean
|
484 |
+
rec_trans = torch.zeros_like(rec_trans_v)
|
485 |
+
rec_trans = torch.cumsum(rec_trans_v, dim=-2)
|
486 |
+
rec_trans[...,1]=rec_trans_v[...,1]
|
487 |
+
rec_lower = rec_lower[...,:-3]
|
488 |
+
|
489 |
+
if args.pose_norm:
|
490 |
+
rec_upper = rec_upper * std_upper + mean_upper
|
491 |
+
rec_hands = rec_hands * std_hands + mean_hands
|
492 |
+
rec_lower = rec_lower * std_lower + mean_lower
|
493 |
+
|
494 |
+
|
495 |
+
|
496 |
+
|
497 |
+
n = n - remain
|
498 |
+
tar_pose = tar_pose[:, :n, :]
|
499 |
+
tar_exps = tar_exps[:, :n, :]
|
500 |
+
tar_trans = tar_trans[:, :n, :]
|
501 |
+
tar_beta = tar_beta[:, :n, :]
|
502 |
+
|
503 |
+
|
504 |
+
rec_exps = tar_exps
|
505 |
+
#rec_pose_jaw = rec_face[:, :, :6]
|
506 |
+
rec_pose_legs = rec_lower[:, :, :54]
|
507 |
+
bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]
|
508 |
+
rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
|
509 |
+
rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)#
|
510 |
+
rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3)
|
511 |
+
rec_pose_upper_recover = _warp_inverse_selection_tensor(rec_pose_upper, joint_mask_upper, bs*n)
|
512 |
+
rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
|
513 |
+
rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
|
514 |
+
rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6)
|
515 |
+
rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3)
|
516 |
+
rec_pose_lower_recover = _warp_inverse_selection_tensor(rec_pose_lower, joint_mask_lower, bs*n)
|
517 |
+
rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
|
518 |
+
rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
|
519 |
+
rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3)
|
520 |
+
rec_pose_hands_recover = _warp_inverse_selection_tensor(rec_pose_hands, joint_mask_hands, bs*n)
|
521 |
+
rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
|
522 |
+
rec_pose[:, 66:69] = tar_pose.reshape(bs*n, 55*3)[:, 66:69]
|
523 |
+
|
524 |
+
rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3))
|
525 |
+
rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
|
526 |
+
tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3))
|
527 |
+
tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
|
528 |
+
|
529 |
+
return {
|
530 |
+
'rec_pose': rec_pose,
|
531 |
+
'rec_trans': rec_trans,
|
532 |
+
'tar_pose': tar_pose,
|
533 |
+
'tar_exps': tar_exps,
|
534 |
+
'tar_beta': tar_beta,
|
535 |
+
'tar_trans': tar_trans,
|
536 |
+
'rec_exps': rec_exps,
|
537 |
+
}
|
538 |
+
|
539 |
+
|
540 |
+
|
541 |
+
def _warp_load_data(dict_data,joints,joint_mask_upper,joint_mask_hands,joint_mask_lower,args,use_trans,mean_upper,mean_hands,mean_lower,std_upper,std_hands,std_lower,trans_mean,trans_std,vq_model_upper,vq_model_hands,vq_model_lower):
    """Unpack a batch dict, split the pose into body parts, and encode to latents.

    The packed pose tensor carries 165 axis-angle dims (55 joints * 3) followed
    by 4 contact dims. Each body part (jaw, hands, upper, lower) is converted
    to rotation-6d; parts are optionally normalized (``args.pose_norm``) and,
    for the lower body, concatenated with the normalized translation velocity
    (``use_trans``), then mapped through the frozen RVQ-VAE encoders.

    Returns a dict with the raw targets, the per-part 6d poses, the per-part
    latents, the scaled concatenated ``latent_in``, the full ``latent_all``
    sequence, and an optional MotionCLIP ``style_feature``.

    NOTE(review): ``joint_mask_hands`` is accepted but never used, and
    ``motionclip`` is read from module scope rather than passed in — confirm
    both are intentional.
    """
    tar_pose_raw = dict_data["pose"]
    tar_pose = tar_pose_raw[:, :, :165].cuda()        # axis-angle pose, 55*3 dims
    tar_contact = tar_pose_raw[:, :, 165:169].cuda()  # 4 contact channels
    tar_trans = dict_data["trans"].cuda()
    tar_trans_v = dict_data["trans_v"].cuda()
    tar_exps = dict_data["facial"].cuda()
    in_audio = dict_data["audio"].cuda()
    in_word = dict_data["word"].cuda()
    tar_beta = dict_data["beta"].cuda()
    tar_id = dict_data["id"].cuda().long()
    bs, n, j = tar_pose.shape[0], tar_pose.shape[1], joints

    # Jaw (1 joint, dims 66:69) -> rotation-6d; concatenated with expressions.
    tar_pose_jaw = tar_pose[:, :, 66:69]
    tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3))
    tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1 * 6)
    tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2)

    # Hands: joints 25..54 (30 joints) -> rotation-6d.
    tar_pose_hands = tar_pose[:, :, 25 * 3:55 * 3]
    tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3))
    tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30 * 6)

    # Upper body: 13 masked joints -> rotation-6d.
    tar_pose_upper = tar_pose[:, :, joint_mask_upper.astype(bool)]
    tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3))
    tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13 * 6)

    # Lower body / legs: 9 masked joints -> rotation-6d.
    tar_pose_leg = tar_pose[:, :, joint_mask_lower.astype(bool)]
    tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3))
    tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9 * 6)

    tar_pose_lower = tar_pose_leg

    # Discriminator target uses the *un-normalized* 6d parts (built before norm).
    tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2)

    if args.pose_norm:
        tar_pose_upper = (tar_pose_upper - mean_upper) / std_upper
        tar_pose_hands = (tar_pose_hands - mean_hands) / std_hands
        tar_pose_lower = (tar_pose_lower - mean_lower) / std_lower

    if use_trans:
        tar_trans_v = (tar_trans_v - trans_mean) / trans_std
        tar_pose_lower = torch.cat([tar_pose_lower, tar_trans_v], dim=-1)

    latent_face_top = None  # face VQ-VAE intentionally not used here
    latent_upper_top = vq_model_upper.map2latent(tar_pose_upper)
    latent_hands_top = vq_model_hands.map2latent(tar_pose_hands)
    latent_lower_top = vq_model_lower.map2latent(tar_pose_lower)

    latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2) / args.vqvae_latent_scale

    # Full-body pose in rotation-6d, plus translation and contact -> latent_all.
    tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3))
    tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55 * 6)
    latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1)

    style_feature = None
    if args.use_motionclip:
        # MotionCLIP consumes the first 22 joints in 6d, channel-first.
        motionclip_feat = tar_pose_6d[..., :22 * 6]
        bs, seq, _ = motionclip_feat.shape  # was: unused `feat` local
        batch = {
            'x': motionclip_feat.permute(0, 2, 1).contiguous(),
            'y': torch.zeros(bs).int().cuda(),
            'mask': torch.ones([bs, seq]).bool().cuda(),
        }
        style_feature = motionclip.encoder(batch)['mu'].detach().float()

    return {
        "tar_pose_jaw": tar_pose_jaw,
        "tar_pose_face": tar_pose_face,
        "tar_pose_upper": tar_pose_upper,
        "tar_pose_lower": tar_pose_lower,
        "tar_pose_hands": tar_pose_hands,
        'tar_pose_leg': tar_pose_leg,
        "in_audio": in_audio,
        "in_word": in_word,
        "tar_trans": tar_trans,
        "tar_exps": tar_exps,
        "tar_beta": tar_beta,
        "tar_pose": tar_pose,
        "tar4dis": tar4dis,
        "latent_face_top": latent_face_top,
        "latent_upper_top": latent_upper_top,
        "latent_hands_top": latent_hands_top,
        "latent_lower_top": latent_lower_top,
        "latent_in": latent_in,
        "tar_id": tar_id,
        "latent_all": latent_all,
        "tar_pose_6d": tar_pose_6d,
        "tar_contact": tar_contact,
        "style_feature": style_feature,
    }
|
635 |
+
|
636 |
+
|
637 |
+
def _build_part_rvqvae(args, body_part, dim_pose):
    """Construct one RVQVAE for `body_part`; sets args.body_part before building."""
    args.body_part = body_part
    return RVQVAE(args,
                  dim_pose,
                  args.nb_code,
                  args.code_dim,
                  args.code_dim,
                  args.down_t,
                  args.stride_t,
                  args.width,
                  args.depth,
                  args.dilation_growth_rate,
                  args.vq_act,
                  args.vq_norm)


def _warp_create_cuda_model(args, model):
    """Load the generator checkpoint and the three per-part RVQ-VAEs, move
    everything to CUDA in eval mode, and convert the normalization statistics
    to CUDA tensors.

    Returns the (mutated) args, the models, the statistics tensors, and the
    latent scale, in the order callers unpack them.

    Fixes over the previous revision: `mean_upper = torch.from_numpy(mean_upper)`
    (and its seven siblings) assigned to names read from module scope —
    assignment anywhere in a function body makes the name local for the whole
    function, so those lines raised UnboundLocalError; they now bind fresh
    `_t` locals. Also removed a no-op `args = args` and a duplicated
    `args.code_dim = 512`.

    NOTE(review): the stats (mean_upper, std_upper, trans_mean, ...) are still
    read from module scope rather than passed in — confirm that is intentional.
    """
    other_tools.load_checkpoints(model, args.test_ckpt, args.g_name)

    # Fixed RVQ-VAE hyper-parameters used at inference time.
    args.num_quantizers = 6
    args.shared_codebook = False
    args.quantize_dropout_prob = 0.2
    args.mu = 0.99
    args.nb_code = 512
    args.code_dim = 512  # was assigned twice in the original; once is enough
    args.down_t = 2
    args.stride_t = 2
    args.width = 512
    args.depth = 3
    args.dilation_growth_rate = 3
    args.vq_act = "relu"
    args.vq_norm = None

    vq_model_upper = _build_part_rvqvae(args, "upper", 78)   # 13 joints * 6d
    vq_model_hands = _build_part_rvqvae(args, "hands", 180)  # 30 joints * 6d

    dim_pose = 54  # 9 joints * 6d
    if args.use_trans:
        dim_pose = 57  # + 3 translation-velocity dims
        args.vqvae_lower_path = args.vqvae_lower_trans_path
    vq_model_lower = _build_part_rvqvae(args, "lower", dim_pose)

    # NOTE: torch.load unpickles arbitrary objects — only load trusted files.
    vq_model_upper.load_state_dict(torch.load(args.vqvae_upper_path)['net'])
    vq_model_hands.load_state_dict(torch.load(args.vqvae_hands_path)['net'])
    vq_model_lower.load_state_dict(torch.load(args.vqvae_lower_path)['net'])

    vqvae_latent_scale = args.vqvae_latent_scale

    vq_model_upper.eval().cuda()
    vq_model_hands.eval().cuda()
    vq_model_lower.eval().cuda()

    model = model.cuda()
    model.eval()

    # Normalization statistics: module-level numpy arrays -> CUDA tensors.
    mean_upper_t = torch.from_numpy(mean_upper).cuda()
    mean_hands_t = torch.from_numpy(mean_hands).cuda()
    mean_lower_t = torch.from_numpy(mean_lower).cuda()
    std_upper_t = torch.from_numpy(std_upper).cuda()
    std_hands_t = torch.from_numpy(std_hands).cuda()
    std_lower_t = torch.from_numpy(std_lower).cuda()
    trans_mean_t = torch.from_numpy(trans_mean).cuda()
    trans_std_t = torch.from_numpy(trans_std).cuda()

    return (args, model, vq_model_upper, vq_model_hands, vq_model_lower,
            mean_upper_t, mean_hands_t, mean_lower_t,
            std_upper_t, std_hands_t, std_lower_t,
            trans_mean_t, trans_std_t, vqvae_latent_scale)
|
727 |
+
|
728 |
+
|
729 |
+
|
730 |
@logger.catch
|
731 |
def syntalker(audio_path,sample_stratege):
|
732 |
args = config.parse_args()
|