import os
import signal
import time
import csv
import sys
import warnings
import random
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import numpy as np
import pprint
from loguru import logger
import smplx
from torch.utils.tensorboard import SummaryWriter
import wandb
import matplotlib.pyplot as plt
from utils import config, logger_tools, other_tools_hf, metric, data_transfer, other_tools
from dataloaders import data_tools
from dataloaders.build_vocab import Vocab
from optimizers.optim_factory import create_optimizer
from optimizers.scheduler_factory import create_scheduler
from optimizers.loss_factory import get_loss_func
from dataloaders.data_tools import joints_list
from utils import rotation_conversions as rc
import soundfile as sf
import librosa
import subprocess
from transformers import pipeline
from diffusion.model_util import create_gaussian_diffusion
from diffusion.resample import create_named_schedule_sampler
from models.vq.model import RVQVAE
import spaces
import pickle

os.environ['PYOPENGL_PLATFORM'] = 'egl'

command = ["bash", "./demo/install_mfa1.sh"]
result = subprocess.run(command, capture_output=True, text=True)
print("debug0: ", result)
# command = ["bash", "./demo/install_mfa.sh"]
# result = subprocess.run(command, capture_output=True, text=True)
# print("debug1: ", result)

device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    chunk_length_s=30,
    device='cpu',
)

# @spaces.GPU()
def run_pipeline(audio):
    return pipe(audio, batch_size=8)["text"]

debug = False


class BaseTrainer(object):
    def __init__(self, args, ap):
        args.use_ddim = True
        hf_dir = "hf"
        time_local = time.localtime()
        time_name_expend = "%02d%02d_%02d%02d%02d_" % (
            time_local[1], time_local[2], time_local[3], time_local[4], time_local[5])
        self.time_name_expend = time_name_expend
        tmp_dir = args.out_path + "custom/" + time_name_expend + hf_dir
        if not os.path.exists(tmp_dir + "/"):
            os.makedirs(tmp_dir + "/")
        self.audio_path = tmp_dir + "/tmp.wav"
        sf.write(self.audio_path, ap[1], ap[0])
        audio, ssr = librosa.load(self.audio_path, sr=args.audio_sr)

        # use the ASR model to get the corresponding text transcript
        file_path = tmp_dir + "/tmp.lab"
        self.textgrid_path = tmp_dir + "/tmp.TextGrid"
        if not debug:
            text = run_pipeline(audio)
            with open(file_path, "w", encoding="utf-8") as file:
                file.write(text)

            # use Montreal Forced Aligner to get the TextGrid
            # command = ["mfa", "align", tmp_dir, "english_us_arpa", "english_us_arpa", tmp_dir]
            # result = subprocess.run(command, capture_output=True, text=True)
            # print("debug2: ", result)
            command = ["bash", "./demo/run_mfa.sh", tmp_dir]
            result = subprocess.run(command, capture_output=True, text=True)
            print("debug2: ", result)

        ap = (ssr, audio)
        self.args = args
        self.rank = 0  # dist.get_rank()

        args.textgrid_file_path = self.textgrid_path
        args.audio_file_path = self.audio_path

        self.checkpoint_path = tmp_dir
        args.tmp_dir = tmp_dir
        self.test_data = __import__(f"dataloaders.{args.dataset}", fromlist=["something"]).CustomDataset(args, "test")
        self.test_loader = torch.utils.data.DataLoader(
            self.test_data,
            batch_size=1,
            shuffle=False,
            num_workers=args.loader_workers,
            drop_last=False,
        )
        logger.info("Init test dataloader success")

        from models.denoiser import MDM
        self.model = MDM(args)

        if self.rank == 0:
            logger.info(self.model)
logger.info(f"init {args.g_name} success") self.args = args self.ori_joint_list = joints_list[self.args.ori_joints] self.tar_joint_list_face = joints_list["beat_smplx_face"] self.tar_joint_list_upper = joints_list["beat_smplx_upper"] self.tar_joint_list_hands = joints_list["beat_smplx_hands"] self.tar_joint_list_lower = joints_list["beat_smplx_lower"] self.joint_mask_face = np.zeros(len(list(self.ori_joint_list.keys()))*3) self.joints = 55 for joint_name in self.tar_joint_list_face: self.joint_mask_face[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 self.joint_mask_upper = np.zeros(len(list(self.ori_joint_list.keys()))*3) for joint_name in self.tar_joint_list_upper: self.joint_mask_upper[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 self.joint_mask_hands = np.zeros(len(list(self.ori_joint_list.keys()))*3) for joint_name in self.tar_joint_list_hands: self.joint_mask_hands[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 self.joint_mask_lower = np.zeros(len(list(self.ori_joint_list.keys()))*3) for joint_name in self.tar_joint_list_lower: self.joint_mask_lower[self.ori_joint_list[joint_name][1] - self.ori_joint_list[joint_name][0]:self.ori_joint_list[joint_name][1]] = 1 self.tracker = other_tools.EpochTracker(["fid", "l1div", "bc", "rec", "trans", "vel", "transv", 'dis', 'gen', 'acc', 'transa', 'exp', 'lvd', 'mse', "cls", "rec_face", "latent", "cls_full", "cls_self", "cls_word", "latent_word","latent_self","predict_x0_loss"], [False,True,True, False, False, False, False, False, False, False, False, False, False, False, False, False, False,False, False, False,False,False,False]) vq_model_module = __import__(f"models.motion_representation", fromlist=["something"]) self.args.vae_layer = 2 self.args.vae_length = 256 self.args.vae_test_dim = 106 # self.vq_model_face = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) # other_tools.load_checkpoints(self.vq_model_face, "./datasets/hub/pretrained_vq/face_vertex_1layer_790.bin", args.e_name) vq_type = self.args.vqvae_type if vq_type=="vqvae": self.args.vae_layer = 4 self.args.vae_test_dim = 78 self.vq_model_upper = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) other_tools.load_checkpoints(self.vq_model_upper, args.vqvae_upper_path, args.e_name) self.args.vae_test_dim = 180 self.vq_model_hands = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) other_tools.load_checkpoints(self.vq_model_hands, args.vqvae_hands_path, args.e_name) self.args.vae_test_dim = 54 self.args.vae_layer = 4 self.vq_model_lower = getattr(vq_model_module, "VQVAEConvZero")(self.args).to(self.rank) other_tools.load_checkpoints(self.vq_model_lower, args.vqvae_lower_path, args.e_name) self.args.vae_test_dim = 61 self.args.vae_layer = 4 self.args.vae_test_dim = 330 self.args.vae_layer = 4 self.args.vae_length = 240 self.cls_loss = nn.NLLLoss().to(self.rank) self.reclatent_loss = nn.MSELoss().to(self.rank) self.vel_loss = torch.nn.L1Loss(reduction='mean').to(self.rank) self.rec_loss = get_loss_func("GeodesicLoss").to(self.rank) self.log_softmax = nn.LogSoftmax(dim=2).to(self.rank) self.diffusion = create_gaussian_diffusion(use_ddim=args.use_ddim) self.schedule_sampler_type = 'uniform' self.schedule_sampler = create_named_schedule_sampler(self.schedule_sampler_type, self.diffusion) self.mean = np.load(args.mean_pose_path) self.std = 
        self.use_trans = args.use_trans
        if self.use_trans:
            self.trans_mean = np.load(args.mean_trans_path)
            self.trans_std = np.load(args.std_trans_path)

        joints = [3, 6, 9, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]
        upper_body_mask = []
        for i in joints:
            upper_body_mask.extend([i*6, i*6+1, i*6+2, i*6+3, i*6+4, i*6+5])

        joints = list(range(25, 55))
        hands_body_mask = []
        for i in joints:
            hands_body_mask.extend([i*6, i*6+1, i*6+2, i*6+3, i*6+4, i*6+5])

        joints = [0, 1, 2, 4, 5, 7, 8, 10, 11]
        lower_body_mask = []
        for i in joints:
            lower_body_mask.extend([i*6, i*6+1, i*6+2, i*6+3, i*6+4, i*6+5])

        self.mean_upper = self.mean[upper_body_mask]
        self.mean_hands = self.mean[hands_body_mask]
        self.mean_lower = self.mean[lower_body_mask]
        self.std_upper = self.std[upper_body_mask]
        self.std_hands = self.std[hands_body_mask]
        self.std_lower = self.std[lower_body_mask]

    def inverse_selection(self, filtered_t, selection_array, n):
        original_shape_t = np.zeros((n, selection_array.size))
        selected_indices = np.where(selection_array == 1)[0]
        for i in range(n):
            original_shape_t[i, selected_indices] = filtered_t[i]
        return original_shape_t

    def inverse_selection_tensor(self, filtered_t, selection_array, n):
        selection_array = torch.from_numpy(selection_array).cuda()
        original_shape_t = torch.zeros((n, 165)).cuda()
        selected_indices = torch.where(selection_array == 1)[0]
        for i in range(n):
            original_shape_t[i, selected_indices] = filtered_t[i]
        return original_shape_t

    def test_demo(self, epoch):
        '''
        input audio and text, output motion
        do not calculate loss and metric
        save video
        '''
        results_save_path = self.checkpoint_path + f"/{epoch}/"
        if os.path.exists(results_save_path):
            import shutil
            shutil.rmtree(results_save_path)
        os.makedirs(results_save_path)
        start_time = time.time()
        total_length = 0
        test_seq_list = self.test_data.selected_file
        align = 0
        latent_out = []
        latent_ori = []
        l2_all = 0
        lvel = 0
        # self.eval_copy.eval()
        with torch.no_grad():
            for its, batch_data in enumerate(self.test_loader):
                # loaded_data = self._load_data(batch_data)
                # net_out = self._g_test(loaded_data)
                try:
                    net_out = _warp(
                        self.args, self.model, batch_data, self.joints,
                        self.joint_mask_upper, self.joint_mask_hands, self.joint_mask_lower,
                        self.use_trans,
                        self.mean_upper, self.mean_hands, self.mean_lower,
                        self.std_upper, self.std_hands, self.std_lower,
                        self.trans_mean, self.trans_std,
                    )
                    print("debug8: return try")
                except:
                    print("debug9: return fail, use pickle load file")
                    with open("tmp_file", "rb") as tmp_file:
                        net_out = pickle.load(tmp_file)

                tar_pose = net_out['tar_pose']
                rec_pose = net_out['rec_pose']
                tar_exps = net_out['tar_exps']
                tar_beta = net_out['tar_beta']
                rec_trans = net_out['rec_trans']
                tar_trans = net_out['tar_trans']
                rec_exps = net_out['rec_exps']
                bs, n, j = tar_pose.shape[0], tar_pose.shape[1], self.joints

                if (30 / self.args.pose_fps) != 1:
                    assert 30 % self.args.pose_fps == 0
                    n *= int(30 / self.args.pose_fps)
                    tar_pose = torch.nn.functional.interpolate(
                        tar_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0, 2, 1)
                    rec_pose = torch.nn.functional.interpolate(
                        rec_pose.permute(0, 2, 1), scale_factor=30/self.args.pose_fps, mode='linear').permute(0, 2, 1)

                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6))
                rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6))
                tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
                rec_pose = rc.rotation_6d_to_matrix(rec_pose.reshape(bs*n, j, 6))
                rec_pose = rc.matrix_to_axis_angle(rec_pose).reshape(bs*n, j*3)
                tar_pose = rc.rotation_6d_to_matrix(tar_pose.reshape(bs*n, j, 6))
                tar_pose = rc.matrix_to_axis_angle(tar_pose).reshape(bs*n, j*3)

                tar_pose_np = tar_pose.numpy()
                rec_pose_np = rec_pose.numpy()
                rec_trans_np = rec_trans.numpy().reshape(bs*n, 3)
                rec_exp_np = rec_exps.numpy().reshape(bs*n, 100)
                tar_exp_np = tar_exps.numpy().reshape(bs*n, 100)
                tar_trans_np = tar_trans.numpy().reshape(bs*n, 3)

                gt_npz = np.load("./demo/examples/2_scott_0_1_1.npz", allow_pickle=True)
                results_npz_file_save_path = results_save_path + f"result_{self.time_name_expend[:-1]}" + '.npz'
                np.savez(
                    results_npz_file_save_path,
                    betas=gt_npz["betas"],
                    poses=rec_pose_np,
                    expressions=rec_exp_np,
                    trans=rec_trans_np,
                    model='smplx2020',
                    gender='neutral',
                    mocap_frame_rate=30,
                )
                total_length += n

                render_vid_path = None
                if self.args.render_video:
                    render_vid_path = other_tools_hf.render_one_sequence_no_gt(
                        results_npz_file_save_path,
                        # results_save_path+"gt_"+test_seq_list.iloc[its]['id']+'.npz',
                        results_save_path,
                        self.audio_path,
                        self.args.data_path_1 + "smplx_models/",
                        use_matplotlib=False,
                        args=self.args,
                    )

                result = [
                    gr.Video(value=render_vid_path, visible=True),
                    gr.File(value=results_npz_file_save_path, label="download motion and visualize in blender"),
                ]

        end_time = time.time() - start_time
        logger.info(f"total inference time: {int(end_time)} s for {int(total_length/self.args.pose_fps)} s motion")
        return result


@spaces.GPU(duration=60)
def _warp(args, model, batch_data, joints, joint_mask_upper, joint_mask_hands, joint_mask_lower, use_trans,
          mean_upper, mean_hands, mean_lower, std_upper, std_hands, std_lower, trans_mean, trans_std):
    diffusion = create_gaussian_diffusion(use_ddim=args.use_ddim)
    (args, model, vq_model_upper, vq_model_hands, vq_model_lower,
     mean_upper, mean_hands, mean_lower, std_upper, std_hands, std_lower,
     trans_mean, trans_std, vqvae_latent_scale) = _warp_create_cuda_model(
        args, model, mean_upper, mean_hands, mean_lower, std_upper, std_hands, std_lower, trans_mean, trans_std)
    loaded_data = _warp_load_data(
        batch_data, joints, joint_mask_upper, joint_mask_hands, joint_mask_lower, args, use_trans,
        mean_upper, mean_hands, mean_lower, std_upper, std_hands, std_lower, trans_mean, trans_std,
        vq_model_upper, vq_model_hands, vq_model_lower,
    )
    net_out = _warp_g_test(
        loaded_data, diffusion, args, joints, joint_mask_upper, joint_mask_hands, joint_mask_lower, model,
        vqvae_latent_scale, vq_model_upper, vq_model_hands, vq_model_lower, use_trans,
        trans_std, trans_mean, std_upper, std_hands, std_lower, mean_upper, mean_hands, mean_lower,
    )
    with open("tmp_file", "wb") as tmp_file:
        pickle.dump(net_out, tmp_file)
    return net_out


def _warp_inverse_selection_tensor(filtered_t, selection_array, n):
    selection_array = torch.from_numpy(selection_array).cuda()
    original_shape_t = torch.zeros((n, 165)).cuda()
    selected_indices = torch.where(selection_array == 1)[0]
    for i in range(n):
        original_shape_t[i, selected_indices] = filtered_t[i]
    return original_shape_t


def _warp_g_test(loaded_data, diffusion, args, joints, joint_mask_upper, joint_mask_hands, joint_mask_lower, model,
                 vqvae_latent_scale, vq_model_upper, vq_model_hands, vq_model_lower, use_trans,
                 trans_std, trans_mean, std_upper, std_hands, std_lower, mean_upper, mean_hands, mean_lower):
    sample_fn = diffusion.p_sample_loop
    if args.use_ddim:
        sample_fn = diffusion.ddim_sample_loop
    mode = 'test'
    bs, n, j = loaded_data["tar_pose"].shape[0], loaded_data["tar_pose"].shape[1], joints
    tar_pose = loaded_data["tar_pose"]
    tar_beta = loaded_data["tar_beta"]
    tar_exps = loaded_data["tar_exps"]
    tar_contact = loaded_data["tar_contact"]
    tar_trans = loaded_data["tar_trans"]
    in_word = loaded_data["in_word"]
    in_audio = loaded_data["in_audio"]
loaded_data["in_audio"] in_x0 = loaded_data['latent_in'] in_seed = loaded_data['latent_in'] remain = n%8 if remain != 0: tar_pose = tar_pose[:, :-remain, :] tar_beta = tar_beta[:, :-remain, :] tar_trans = tar_trans[:, :-remain, :] in_word = in_word[:, :-remain] tar_exps = tar_exps[:, :-remain, :] tar_contact = tar_contact[:, :-remain, :] in_x0 = in_x0[:, :in_x0.shape[1]-(remain//args.vqvae_squeeze_scale), :] in_seed = in_seed[:, :in_x0.shape[1]-(remain//args.vqvae_squeeze_scale), :] n = n - remain tar_pose_jaw = tar_pose[:, :, 66:69] tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) tar_pose_hands = tar_pose[:, :, 25*3:55*3] tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) tar_pose_upper = tar_pose[:, :, joint_mask_upper.astype(bool)] tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) tar_pose_leg = tar_pose[:, :, joint_mask_lower.astype(bool)] tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) tar_pose_lower = torch.cat([tar_pose_leg, tar_trans, tar_contact], dim=2) tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) rec_all_face = [] rec_all_upper = [] rec_all_lower = [] rec_all_hands = [] vqvae_squeeze_scale = args.vqvae_squeeze_scale roundt = (n - args.pre_frames * vqvae_squeeze_scale) // (args.pose_length - args.pre_frames * vqvae_squeeze_scale) remain = (n - args.pre_frames * vqvae_squeeze_scale) % (args.pose_length - args.pre_frames * vqvae_squeeze_scale) round_l = args.pose_length - args.pre_frames * vqvae_squeeze_scale print("debug3:finish it!") for i in range(0, roundt): in_word_tmp = in_word[:, i*(round_l):(i+1)*(round_l)+args.pre_frames * vqvae_squeeze_scale] in_audio_tmp = in_audio[:, i*(16000//30*round_l):(i+1)*(16000//30*round_l)+16000//30*args.pre_frames * vqvae_squeeze_scale] in_id_tmp = loaded_data['tar_id'][:, i*(round_l):(i+1)*(round_l)+args.pre_frames] in_seed_tmp = in_seed[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+args.pre_frames] in_x0_tmp = in_x0[:, i*(round_l)//vqvae_squeeze_scale:(i+1)*(round_l)//vqvae_squeeze_scale+args.pre_frames] mask_val = torch.ones(bs, args.pose_length, args.pose_dims+3+4).float().cuda() mask_val[:, :args.pre_frames, :] = 0.0 if i == 0: in_seed_tmp = in_seed_tmp[:, :args.pre_frames, :] else: in_seed_tmp = last_sample[:, -args.pre_frames:, :] cond_ = {'y':{}} cond_['y']['audio'] = in_audio_tmp cond_['y']['word'] = in_word_tmp cond_['y']['id'] = in_id_tmp cond_['y']['seed'] =in_seed_tmp cond_['y']['mask'] = (torch.zeros([args.batch_size, 1, 1, args.pose_length]) < 1).cuda() cond_['y']['style_feature'] = torch.zeros([bs, 512]).cuda() shape_ = (bs, 1536, 1, 32) sample = sample_fn( model, shape_, clip_denoised=False, model_kwargs=cond_, skip_timesteps=0, init_image=None, progress=True, dump_steps=None, noise=None, const_noise=False, ) sample = sample.squeeze().permute(1,0).unsqueeze(0) last_sample = sample.clone() rec_latent_upper = sample[...,:512] rec_latent_hands = sample[...,512:1024] 
        rec_latent_lower = sample[..., 1024:1536]

        if i == 0:
            rec_all_upper.append(rec_latent_upper)
            rec_all_hands.append(rec_latent_hands)
            rec_all_lower.append(rec_latent_lower)
        else:
            rec_all_upper.append(rec_latent_upper[:, args.pre_frames:])
            rec_all_hands.append(rec_latent_hands[:, args.pre_frames:])
            rec_all_lower.append(rec_latent_lower[:, args.pre_frames:])

    print("debug4:finish it!")
    rec_all_upper = torch.cat(rec_all_upper, dim=1) * vqvae_latent_scale
    rec_all_hands = torch.cat(rec_all_hands, dim=1) * vqvae_latent_scale
    rec_all_lower = torch.cat(rec_all_lower, dim=1) * vqvae_latent_scale

    rec_upper = vq_model_upper.latent2origin(rec_all_upper)[0]
    rec_hands = vq_model_hands.latent2origin(rec_all_hands)[0]
    rec_lower = vq_model_lower.latent2origin(rec_all_lower)[0]

    if use_trans:
        rec_trans_v = rec_lower[..., -3:]
        rec_trans_v = rec_trans_v * trans_std + trans_mean
        rec_trans = torch.zeros_like(rec_trans_v)
        rec_trans = torch.cumsum(rec_trans_v, dim=-2)
        rec_trans[..., 1] = rec_trans_v[..., 1]
        rec_lower = rec_lower[..., :-3]

    if args.pose_norm:
        rec_upper = rec_upper * std_upper + mean_upper
        rec_hands = rec_hands * std_hands + mean_hands
        rec_lower = rec_lower * std_lower + mean_lower

    n = n - remain
    tar_pose = tar_pose[:, :n, :]
    tar_exps = tar_exps[:, :n, :]
    tar_trans = tar_trans[:, :n, :]
    tar_beta = tar_beta[:, :n, :]

    rec_exps = tar_exps
    # rec_pose_jaw = rec_face[:, :, :6]
    rec_pose_legs = rec_lower[:, :, :54]
    bs, n = rec_pose_legs.shape[0], rec_pose_legs.shape[1]

    rec_pose_upper = rec_upper.reshape(bs, n, 13, 6)
    rec_pose_upper = rc.rotation_6d_to_matrix(rec_pose_upper)
    rec_pose_upper = rc.matrix_to_axis_angle(rec_pose_upper).reshape(bs*n, 13*3)
    rec_pose_upper_recover = _warp_inverse_selection_tensor(rec_pose_upper, joint_mask_upper, bs*n)

    rec_pose_lower = rec_pose_legs.reshape(bs, n, 9, 6)
    rec_pose_lower = rc.rotation_6d_to_matrix(rec_pose_lower)
    rec_lower2global = rc.matrix_to_rotation_6d(rec_pose_lower.clone()).reshape(bs, n, 9*6)
    rec_pose_lower = rc.matrix_to_axis_angle(rec_pose_lower).reshape(bs*n, 9*3)
    rec_pose_lower_recover = _warp_inverse_selection_tensor(rec_pose_lower, joint_mask_lower, bs*n)

    rec_pose_hands = rec_hands.reshape(bs, n, 30, 6)
    rec_pose_hands = rc.rotation_6d_to_matrix(rec_pose_hands)
    rec_pose_hands = rc.matrix_to_axis_angle(rec_pose_hands).reshape(bs*n, 30*3)
    rec_pose_hands_recover = _warp_inverse_selection_tensor(rec_pose_hands, joint_mask_hands, bs*n)

    rec_pose = rec_pose_upper_recover + rec_pose_lower_recover + rec_pose_hands_recover
    rec_pose[:, 66:69] = tar_pose.reshape(bs*n, 55*3)[:, 66:69]

    rec_pose = rc.axis_angle_to_matrix(rec_pose.reshape(bs*n, j, 3))
    rec_pose = rc.matrix_to_rotation_6d(rec_pose).reshape(bs, n, j*6)
    tar_pose = rc.axis_angle_to_matrix(tar_pose.reshape(bs*n, j, 3))
    tar_pose = rc.matrix_to_rotation_6d(tar_pose).reshape(bs, n, j*6)
    print("debug5:finish it!")

    return {
        'rec_pose': rec_pose.detach().cpu(),
        'rec_trans': rec_trans.detach().cpu(),
        'tar_pose': tar_pose.detach().cpu(),
        'tar_exps': tar_exps.detach().cpu(),
        'tar_beta': tar_beta.detach().cpu(),
        'tar_trans': tar_trans.detach().cpu(),
        'rec_exps': rec_exps.detach().cpu(),
    }


def _warp_load_data(dict_data, joints, joint_mask_upper, joint_mask_hands, joint_mask_lower, args, use_trans,
                    mean_upper, mean_hands, mean_lower, std_upper, std_hands, std_lower, trans_mean, trans_std,
                    vq_model_upper, vq_model_hands, vq_model_lower):
    tar_pose_raw = dict_data["pose"]
    tar_pose = tar_pose_raw[:, :, :165].cuda()
    tar_contact = tar_pose_raw[:, :, 165:169].cuda()
    tar_trans = dict_data["trans"].cuda()
    tar_trans_v = dict_data["trans_v"].cuda()
    tar_exps = dict_data["facial"].cuda()
dict_data["facial"].cuda() in_audio = dict_data["audio"].cuda() in_word = dict_data["word"].cuda() tar_beta = dict_data["beta"].cuda() tar_id = dict_data["id"].cuda().long() bs, n, j = tar_pose.shape[0], tar_pose.shape[1], joints tar_pose_jaw = tar_pose[:, :, 66:69] tar_pose_jaw = rc.axis_angle_to_matrix(tar_pose_jaw.reshape(bs, n, 1, 3)) tar_pose_jaw = rc.matrix_to_rotation_6d(tar_pose_jaw).reshape(bs, n, 1*6) tar_pose_face = torch.cat([tar_pose_jaw, tar_exps], dim=2) tar_pose_hands = tar_pose[:, :, 25*3:55*3] tar_pose_hands = rc.axis_angle_to_matrix(tar_pose_hands.reshape(bs, n, 30, 3)) tar_pose_hands = rc.matrix_to_rotation_6d(tar_pose_hands).reshape(bs, n, 30*6) tar_pose_upper = tar_pose[:, :, joint_mask_upper.astype(bool)] tar_pose_upper = rc.axis_angle_to_matrix(tar_pose_upper.reshape(bs, n, 13, 3)) tar_pose_upper = rc.matrix_to_rotation_6d(tar_pose_upper).reshape(bs, n, 13*6) tar_pose_leg = tar_pose[:, :, joint_mask_lower.astype(bool)] tar_pose_leg = rc.axis_angle_to_matrix(tar_pose_leg.reshape(bs, n, 9, 3)) tar_pose_leg = rc.matrix_to_rotation_6d(tar_pose_leg).reshape(bs, n, 9*6) tar_pose_lower = tar_pose_leg tar4dis = torch.cat([tar_pose_jaw, tar_pose_upper, tar_pose_hands, tar_pose_leg], dim=2) if args.pose_norm: tar_pose_upper = (tar_pose_upper - mean_upper) / std_upper tar_pose_hands = (tar_pose_hands - mean_hands) / std_hands tar_pose_lower = (tar_pose_lower - mean_lower) / std_lower if use_trans: tar_trans_v = (tar_trans_v - trans_mean)/trans_std tar_pose_lower = torch.cat([tar_pose_lower,tar_trans_v], dim=-1) latent_face_top = None#self.vq_model_face.map2latent(tar_pose_face) # bs*n/4 latent_upper_top = vq_model_upper.map2latent(tar_pose_upper) latent_hands_top = vq_model_hands.map2latent(tar_pose_hands) latent_lower_top = vq_model_lower.map2latent(tar_pose_lower) latent_in = torch.cat([latent_upper_top, latent_hands_top, latent_lower_top], dim=2)/args.vqvae_latent_scale tar_pose_6d = rc.axis_angle_to_matrix(tar_pose.reshape(bs, n, 55, 3)) tar_pose_6d = rc.matrix_to_rotation_6d(tar_pose_6d).reshape(bs, n, 55*6) latent_all = torch.cat([tar_pose_6d, tar_trans, tar_contact], dim=-1) style_feature = None if args.use_motionclip: motionclip_feat = tar_pose_6d[...,:22*6] batch = {} bs,seq,feat = motionclip_feat.shape batch['x']=motionclip_feat.permute(0,2,1).contiguous() batch['y']=torch.zeros(bs).int().cuda() batch['mask']=torch.ones([bs,seq]).bool().cuda() style_feature = motionclip.encoder(batch)['mu'].detach().float() # print(tar_index_value_upper_top.shape, index_in.shape) return { "tar_pose_jaw": tar_pose_jaw, "tar_pose_face": tar_pose_face, "tar_pose_upper": tar_pose_upper, "tar_pose_lower": tar_pose_lower, "tar_pose_hands": tar_pose_hands, 'tar_pose_leg': tar_pose_leg, "in_audio": in_audio, "in_word": in_word, "tar_trans": tar_trans, "tar_exps": tar_exps, "tar_beta": tar_beta, "tar_pose": tar_pose, "tar4dis": tar4dis, "latent_face_top": latent_face_top, "latent_upper_top": latent_upper_top, "latent_hands_top": latent_hands_top, "latent_lower_top": latent_lower_top, "latent_in": latent_in, "tar_id": tar_id, "latent_all": latent_all, "tar_pose_6d": tar_pose_6d, "tar_contact": tar_contact, "style_feature":style_feature, } def _warp_create_cuda_model(args,model,mean_upper,mean_hands,mean_lower,std_upper,std_hands,std_lower,trans_mean,trans_std): args = args other_tools.load_checkpoints(model, args.test_ckpt, args.g_name) args.num_quantizers = 6 args.shared_codebook = False args.quantize_dropout_prob = 0.2 args.mu = 0.99 args.nb_code = 512 args.code_dim = 512 args.code_dim = 512 
    args.down_t = 2
    args.stride_t = 2
    args.width = 512
    args.depth = 3
    args.dilation_growth_rate = 3
    args.vq_act = "relu"
    args.vq_norm = None

    dim_pose = 78
    args.body_part = "upper"
    vq_model_upper = RVQVAE(args, dim_pose, args.nb_code, args.code_dim, args.code_dim,
                            args.down_t, args.stride_t, args.width, args.depth,
                            args.dilation_growth_rate, args.vq_act, args.vq_norm)

    dim_pose = 180
    args.body_part = "hands"
    vq_model_hands = RVQVAE(args, dim_pose, args.nb_code, args.code_dim, args.code_dim,
                            args.down_t, args.stride_t, args.width, args.depth,
                            args.dilation_growth_rate, args.vq_act, args.vq_norm)

    dim_pose = 54
    if args.use_trans:
        dim_pose = 57
        args.vqvae_lower_path = args.vqvae_lower_trans_path
    args.body_part = "lower"
    vq_model_lower = RVQVAE(args, dim_pose, args.nb_code, args.code_dim, args.code_dim,
                            args.down_t, args.stride_t, args.width, args.depth,
                            args.dilation_growth_rate, args.vq_act, args.vq_norm)

    vq_model_upper.load_state_dict(torch.load(args.vqvae_upper_path)['net'])
    vq_model_hands.load_state_dict(torch.load(args.vqvae_hands_path)['net'])
    vq_model_lower.load_state_dict(torch.load(args.vqvae_lower_path)['net'])
    vqvae_latent_scale = args.vqvae_latent_scale

    vq_model_upper.eval().cuda()
    vq_model_hands.eval().cuda()
    vq_model_lower.eval().cuda()

    model = model.cuda()
    model.eval()

    mean_upper = torch.from_numpy(mean_upper).cuda()
    mean_hands = torch.from_numpy(mean_hands).cuda()
    mean_lower = torch.from_numpy(mean_lower).cuda()
    std_upper = torch.from_numpy(std_upper).cuda()
    std_hands = torch.from_numpy(std_hands).cuda()
    std_lower = torch.from_numpy(std_lower).cuda()
    trans_mean = torch.from_numpy(trans_mean).cuda()
    trans_std = torch.from_numpy(trans_std).cuda()

    return (args, model, vq_model_upper, vq_model_hands, vq_model_lower,
            mean_upper, mean_hands, mean_lower, std_upper, std_hands, std_lower,
            trans_mean, trans_std, vqvae_latent_scale)


@logger.catch
def syntalker(audio_path, sample_stratege, render_video):
    args = config.parse_args()
    args.use_ddim = True
    args.render_video = True
    print("sample_stratege", sample_stratege)
    if sample_stratege == 0:
        args.use_ddim = True
    elif sample_stratege == 1:
        args.use_ddim = False
    if render_video == 0:
        args.render_video = True
    elif render_video == 1:
        args.render_video = False
    print(sample_stratege)
    print(args.use_ddim)
    # os.environ['TRANSFORMERS_CACHE'] = args.data_path_1 + "hub/"
    if not sys.warnoptions:
        warnings.simplefilter("ignore")
    # dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
    # logger_tools.set_args_and_logger(args, rank)
    other_tools_hf.set_random_seed(args)
    other_tools_hf.print_exp_info(args)

    # return one instance of the trainer
    trainer = BaseTrainer(args, ap=audio_path)
    result = trainer.test_demo(999)
    return result


examples = [
    ["demo/examples/2_scott_0_1_1.wav"],
    ["demo/examples/2_scott_0_2_2.wav"],
    ["demo/examples/2_scott_0_3_3.wav"],
    ["demo/examples/2_scott_0_4_4.wav"],
    ["demo/examples/2_scott_0_5_5.wav"],
]

demo = gr.Interface(
    syntalker,  # function
    inputs=[
        # gr.File(label="Please upload SMPL-X file with npz format here.", file_types=["npz", "NPZ"]),
        gr.Audio(),
        gr.Radio(choices=["DDIM", "DDPM"], label="Please select a sampling strategy", type="index", value="DDIM"),  # 0 for DDIM, 1 for DDPM
        gr.Radio(choices=["Yes", "No"], label="Please select whether to render the video (rendering takes about 10 additional minutes)", type="index", value="Yes"),  # 0 for Yes, 1 for No
        # gr.File(label="Please upload textgrid format file here.", file_types=["TextGrid", "Textgrid", "textgrid"])
    ],  # input type
    outputs=[
        gr.Video(format="mp4", visible=True),
        gr.File(label="download motion and visualize in blender"),
    ],
    title='SynTalker: Enabling Synergistic Full-Body Control in Prompt-Based Co-Speech Motion Generation',
    description="1. Upload your audio. \n\
        2. Then, sit back and wait for the rendering to happen! This may take a while (e.g. 2-12 minutes). \n\
        (Inference is slow because the provided GPU has a limited GPU-time quota, so we must handle some GPU tasks on the CPU.) \n\
        3. Afterwards, you can view the video. \n\
        4. Note that we use a fixed face animation; our method only produces body motion. \n\
        5. The DDPM sampling strategy generates better results, but it takes more inference time. \n\
        ",
    article="Project links: [SynTalker](https://robinwitch.github.io/SynTalker-Page). \n\
        Reference links: [EMAGE](https://pantomatrix.github.io/EMAGE/). ",
    examples=examples,
)

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = '127.0.0.1'
    os.environ["MASTER_PORT"] = '8675'
    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    demo.launch(share=True)
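
# A minimal sketch of driving the launched demo programmatically with gradio_client,
# kept commented out so it does not run with the app. This client-side usage is an
# illustration only: the URL, the positional argument order (mirroring the Interface
# inputs above), and the handle_file helper are assumptions, not part of this repository.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860")
#   video_path, motion_npz = client.predict(
#       handle_file("demo/examples/2_scott_0_1_1.wav"),  # input audio
#       "DDIM",  # sampling strategy
#       "Yes",   # render video
#   )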