Spaces:
Running
on
L4
Running
on
L4
import pickle | |
import time | |
import numpy as np | |
import scipy, cv2, os, sys, argparse | |
from tqdm import tqdm | |
import torch | |
import librosa | |
from networks import define_G | |
from pcavs.config.AudioConfig import AudioConfig | |
sys.path.append('spectre') | |
from config import cfg as spectre_cfg | |
from src.spectre import SPECTRE | |
from audio2mesh_helper import * | |
from pcavs.models import create_model, networks | |
torch.manual_seed(0) | |
from scipy.signal import savgol_filter | |
class SimpleWrapperV2(nn.Module): | |
def __init__(self, cfg, use_ref=True, exp_dim=53, noload=False) -> None: | |
super().__init__() | |
self.audio_encoder = networks.define_A_sync(cfg) | |
self.mapping1 = nn.Linear(512+exp_dim, exp_dim) | |
nn.init.constant_(self.mapping1.weight, 0.) | |
nn.init.constant_(self.mapping1.bias, 0.) | |
self.use_ref = use_ref | |
def forward(self, x, ref, use_tanh=False): | |
x = self.audio_encoder.forward_feature(x).view(x.size(0), -1) | |
ref_reshape = ref.reshape(x.size(0), -1) #20, -1 | |
y = self.mapping1(torch.cat([x, ref_reshape], dim=1)) | |
if self.use_ref: | |
out = y.reshape(ref.shape[0], ref.shape[1], -1) + ref # resudial | |
else: | |
out = y.reshape(ref.shape[0], ref.shape[1], -1) | |
if use_tanh: | |
out[:, :50] = torch.tanh(out[:, :50]) * 3 | |
return out | |
class Audio2Mesh(object): | |
def __init__(self, args) -> None: | |
self.args = args | |
spectre_cfg.model.use_tex = True | |
spectre_cfg.model.mask_type = args.mask_type | |
spectre_cfg.debug = self.args.debug | |
spectre_cfg.model.netA_sync = 'ressesync' | |
spectre_cfg.model.gpu_ids = [0] | |
self.spectre = SPECTRE(spectre_cfg) | |
self.spectre.eval() | |
self.face_tracker = None #FaceTrackerV2() # face landmark detection | |
self.mel_step_size = 16 | |
self.fps = args.fps | |
self.Nw = args.tframes | |
self.device = self.args.device | |
self.image_size = self.args.image_size | |
### only audio | |
args.netA_sync = 'ressesync' | |
args.gpu_ids = [0] | |
args.exp_dim = 53 | |
args.use_tanh = False | |
args.K = 20 | |
self.audio2exp = 'pcavs' | |
# | |
self.avmodel = SimpleWrapperV2(args, exp_dim=args.exp_dim).cuda() | |
self.avmodel.load_state_dict(torch.load('../packages/pretrained/audio2expression_v2_model.tar')['opt']) | |
# 5, 160 = 25fps | |
self.audio = AudioConfig(frame_rate=args.fps, num_frames_per_clip=5, hop_size=160) | |
with open(os.path.join(args.source_dir, 'deca_infos.pkl'), 'rb') as f: # ? | |
self.fitting_coeffs = pickle.load(f, encoding='bytes') | |
self.coeffs_dict = { key: torch.Tensor(self.fitting_coeffs[key]).cuda().squeeze(1) for key in ['cam', 'pose', 'light', 'tex', 'shape', 'exp']} | |
#### find the close month | |
exp_tensors = torch.sum(self.coeffs_dict['exp'], dim=1) | |
ssss, sorted_indices = torch.sort(exp_tensors) | |
self.exp_id = sorted_indices[0].item() | |
if '.ts' in args.render_path: | |
self.render = torch.jit.load(args.render_path).cuda() | |
self.trt = True | |
else: | |
self.render = define_G(self.Nw*6, 3, args.ngf, args.netR).eval().cuda() | |
self.render.load_state_dict(torch.load(args.render_path)) | |
self.trt = False | |
print('loaded cached images...') | |
def cg2real(self, rendedimages, start_frame=0): | |
## load original image and the mask | |
self.source_images = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_frame'),\ | |
resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:] | |
self.source_masks = np.concatenate(load_image_from_dir(os.path.join(self.args.source_dir, 'original_mask'),\ | |
resize=self.image_size, limit=len(rendedimages)+start_frame))[start_frame:] | |
self.source_masks = torch.FloatTensor(np.transpose(self.source_masks,(0,3,1,2))/255.) | |
self.padded_real_tensor = torch.FloatTensor(np.transpose(self.source_images,(0,3,1,2))/255.) | |
## padding the rended_imgs | |
paded_tensor = torch.cat([rendedimages[0:1]]* (self.Nw // 2) + [rendedimages] + [rendedimages[-1:]]* (self.Nw // 2)).contiguous() | |
paded_mask_tensor = torch.cat([self.source_masks[0:1]]* (self.Nw // 2) + [self.source_masks] + [self.source_masks[-1:]]* (self.Nw // 2)).contiguous() | |
paded_real_tensor = torch.cat([self.padded_real_tensor[0:1]]* (self.Nw // 2) + [self.padded_real_tensor] + [self.padded_real_tensor[-1:]]* (self.Nw // 2)).contiguous() | |
# paded_mask_tensor = maskErosion(paded_mask_tensor, offY=self.args.mask) | |
padded_input = ((paded_real_tensor-0.5)*2 ) # *(1-paded_mask_tensor) | |
padded_input = torch.nn.functional.interpolate(padded_input, (self.image_size, self.image_size), mode='bilinear', align_corners=False) | |
paded_tensor = torch.nn.functional.interpolate(paded_tensor, (self.image_size, self.image_size), mode='bilinear', align_corners=False) | |
paded_tensor = (paded_tensor-0.5)*2 | |
result = [] | |
for index in tqdm(range(0, len(rendedimages), self.args.renderbs), desc='CG2REAL:'): | |
list_A = [] | |
list_R = [] | |
list_M = [] | |
for i in range(self.args.renderbs): | |
idx = index + i | |
if idx+self.Nw > len(padded_input): | |
list_A.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0)) | |
list_R.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0)) | |
list_M.append(torch.zeros(self.Nw*3,self.image_size,self.image_size).unsqueeze(0)) | |
else: | |
list_A.append(padded_input[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0)) | |
list_R.append(paded_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0)) | |
list_M.append(paded_mask_tensor[idx:idx+self.Nw].view(-1, self.image_size, self.image_size).unsqueeze(0)) | |
list_A = torch.cat(list_A) | |
list_R = torch.cat(list_R) | |
list_M = torch.cat(list_M) | |
idx = (self.Nw//2) * 3 | |
mask = list_M[:, idx:idx+3] | |
# list_A = padded_input | |
mask = maskErosion(mask, offY=self.args.mask) | |
list_A = list_A * (1 - mask[:,0:1]) | |
A = torch.cat([list_A, list_R], 1) | |
if self.trt: | |
B = self.render(A.half().cuda()) | |
elif self.args.netR == 'unet_256': | |
# import pdb; pdb.set_trace() | |
idx = (self.Nw//2) * 3 | |
mask = list_M[:, idx:idx+3].cuda() | |
mask = maskErosion(mask, offY=self.args.mask) | |
B0 = list_A[:, idx:idx+3].cuda() | |
B = self.render(A.cuda()) * mask[:,0:1] + (1 - mask[:,0:1]) * B0 | |
elif self.args.netR == 's2am': | |
# import pdb; pdb.set_trace() | |
idx = (self.Nw//2) * 3 | |
mask = list_M[:, idx:idx+3].cuda() | |
mask = maskErosion(mask, offY=self.args.mask) | |
B0 = list_A[:, idx:idx+3].cuda() | |
B = self.render(A.cuda(), mask[:,0:1] ) * mask[:,0:1] + (1 - mask[:,0:1]) * B0 | |
else: | |
B = self.render(A.cuda()) | |
result.append((B.cpu() + 1) * 0.5) # -1,1 -> 0,1 | |
return torch.cat(result)[:len(rendedimages)] | |
def coeffs_to_img(self, vertices, coeffs, zero_pose=False, XK = 20): | |
xlen = vertices.shape[0] | |
all_shape_images = [] | |
landmark2d = [] | |
#### find the most larger pose 51 in the coeffs. | |
max_pose_51 = torch.max(self.coeffs_dict['pose'][..., 3:4].squeeze(-1)) | |
for i in tqdm(range(0, xlen, XK)): | |
if i + XK > xlen: | |
XK = xlen - i | |
codedictdecoder = {} | |
codedictdecoder['shape'] = torch.zeros_like(self.coeffs_dict['shape'][i:i+XK].cuda()) | |
codedictdecoder['tex'] = self.coeffs_dict['tex'][i:i+XK].cuda() | |
codedictdecoder['exp'] = torch.zeros_like(self.coeffs_dict['exp'][i:i+XK].cuda()) # all_exps[i:i+XK, :50].cuda() # # # vid_exps[i:i+1].cuda() i:i+XK | |
codedictdecoder['pose'] = self.coeffs_dict['pose'][i:i+XK] # vid_poses[i:i+1].cuda() | |
codedictdecoder['cam'] = self.coeffs_dict['cam'][i:i+XK].cuda() # vid_poses[i:i+1].cuda() | |
codedictdecoder['light'] = self.coeffs_dict['light'][i:i+XK].cuda() # vid_poses[i:i+1].cuda() | |
codedictdecoder['images'] = torch.zeros((XK,3,256,256)).cuda() | |
codedictdecoder['pose'][..., 3:4] = torch.clip(coeffs[i:i+XK, 50:51], 0, max_pose_51*0.9) # torch.zeros_like(self.coeffs_dict['pose'][i:i+XK, 3:]) | |
codedictdecoder['pose'][..., 4:6] = 0 # coeffs[i:i+XK, 50:]*( - 0.25) # torch.zeros_like(self.coeffs_dict['pose'][i:i+XK, 3:]) | |
sub_vertices = vertices[i:i+XK].cuda() | |
opdict = self.spectre.decode_verts(codedictdecoder, sub_vertices, rendering=True, vis_lmk=False, return_vis=False) | |
landmark2d.append(opdict['landmarks2d'].cpu()) | |
all_shape_images.append(opdict['rendered_images'].cpu()) | |
rendedimages = torch.cat(all_shape_images) | |
lmk2d = torch.cat(landmark2d) | |
return rendedimages, lmk2d | |
def run_spectre_v3(self, wav=None, ds_features=None, L=20): | |
wav = audio_normalize(wav) | |
all_mel = self.audio.melspectrogram(wav).astype(np.float32).T | |
frames_from_audio = np.arange(2, len(all_mel) // self.audio.num_bins_per_frame - 2) # 2,[]mmmmmmmmmmmmmmmmmmmmmmmmmmmm | |
audio_inds = frame2audio_indexs(frames_from_audio, self.audio.num_frames_per_clip, self.audio.num_bins_per_frame) | |
vid_exps = self.coeffs_dict['exp'][self.exp_id:self.exp_id+1] | |
vid_poses = self.coeffs_dict['pose'][self.exp_id:self.exp_id+1] | |
ref = torch.cat([vid_exps.view(1, 50), vid_poses[:, 3:].view(1, 3)], dim=-1) | |
ref = ref[...,:self.args.exp_dim] | |
K = 20 | |
xlens = len(audio_inds) # len(self.coeffs_dict['exp']) | |
exps = [] | |
for i in tqdm(range(0, xlens, K), desc='S2 DECODER:'+ str(xlens) + ' '): | |
mels = [] | |
for j in range(K): | |
if i + j < xlens: | |
idx = i+j # //3 * 3 | |
mel = load_spectrogram(all_mel, audio_inds[idx], self.audio.num_frames_per_clip * self.audio.num_bins_per_frame).cuda() | |
mel = mel.view(-1, 1, 80, self.audio.num_frames_per_clip * self.audio.num_bins_per_frame) | |
mels.append(mel) | |
else: | |
break | |
mels = torch.cat(mels, dim=0) | |
new_exp = self.avmodel(mels, ref.repeat(mels.shape[0], 1, 1).cuda(), self.args.use_tanh) # exp 53 | |
exps+= [new_exp.view(-1, 53)] | |
all_exps = torch.cat(exps,axis=0) | |
return all_exps | |
def test_model(self, wav_path): | |
sys.path.append('../FaceFormer') | |
from faceformer import Faceformer | |
from transformers import Wav2Vec2FeatureExtractor,Wav2Vec2Processor | |
from faceformer import PeriodicPositionalEncoding, init_biased_mask | |
#build model | |
self.args.train_subjects = " ".join(["A"]*8) # suitable for pre-trained faceformer checkpoint | |
model = Faceformer(self.args) | |
model.load_state_dict(torch.load('/apdcephfs/private_shadowcun/Avatar2dFF/medias/videos/c8/mask5000_l2/6_model.pth')) # ../packages/pretrained/28_ff_model.pth | |
model = model.to(torch.device(self.device)) | |
model.eval() | |
# hacking for long audio generation | |
model.PPE = PeriodicPositionalEncoding(self.args.feature_dim, period = self.args.period, max_seq_len=6000).cuda() | |
model.biased_mask = init_biased_mask(n_head = 4, max_seq_len = 6000, period=self.args.period).cuda() | |
train_subjects_list = ["A"] * 8 | |
one_hot_labels = np.eye(len(train_subjects_list)) | |
one_hot = one_hot_labels[0] | |
one_hot = np.reshape(one_hot,(-1,one_hot.shape[0])) | |
one_hot = torch.FloatTensor(one_hot).to(device=self.device) | |
vertices_npy = np.load(self.args.source_dir + '/mesh_pose0.npy') | |
vertices_npy = np.array(vertices_npy).reshape(-1, 5023*3) | |
temp = vertices_npy[33] # 829 | |
template = temp.reshape((-1)) | |
template = np.reshape(template,(-1,template.shape[0])) | |
template = torch.FloatTensor(template).to(device=self.device) | |
speech_array, sampling_rate = librosa.load(os.path.join(wav_path), sr=16000) | |
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") | |
audio_feature = np.squeeze(processor(speech_array,sampling_rate=16000).input_values) | |
audio_feature = np.reshape(audio_feature,(-1,audio_feature.shape[0])) | |
audio_feature = torch.FloatTensor(audio_feature).to(device=self.device) | |
prediction = model.predict(audio_feature, template, one_hot, 1.0) # (1, seq_len, V*3) | |
return prediction.squeeze() | |
def run(self, face, audio, start_frame=0): | |
wav, sr = librosa.load(audio, sr=16000) # 16*80 ? 20*80 | |
wav_tensor = torch.FloatTensor(wav).unsqueeze(0) if len(wav.shape) == 1 else torch.FloatTensor(wav) | |
_, frames = parse_audio_length(wav_tensor.shape[1], 16000, self.args.fps) | |
##### audio-guided, only use the jaw movement | |
all_exps = self.run_spectre_v3(wav) | |
# #### temp. interpolation | |
all_exps = torch.nn.functional.interpolate(all_exps.unsqueeze(0).permute([0,2,1]), size=frames, mode='linear') | |
all_exps = all_exps.permute([0,2,1]).squeeze(0) | |
# run faceformer for face mesh generation | |
predicted_vertices = self.test_model(audio) | |
predicted_vertices = predicted_vertices.view(-1, 5023*3) | |
#### temp. interpolation | |
predicted_vertices = torch.nn.functional.interpolate(predicted_vertices.unsqueeze(0).permute([0,2,1]), size=frames, mode='linear') | |
predicted_vertices = predicted_vertices.permute([0,2,1]).squeeze(0).view(-1, 5023, 3) | |
all_exps = torch.Tensor(savgol_filter(all_exps.cpu().numpy(), 5, 3, axis=0)).cpu() # smooth GT | |
rendedimages, lm2d = self.coeffs_to_img(predicted_vertices, all_exps, zero_pose=True) | |
debug_video_gen(rendedimages, self.args.result_dir+"/debug_before_ff.mp4", wav_tensor, self.args.fps, sr) | |
# cg2real | |
debug_video_gen(self.cg2real(rendedimages, start_frame=start_frame), self.args.result_dir+"/debug_cg2real_raw.mp4", wav_tensor, self.args.fps, sr) | |
exit() | |
if __name__ == '__main__': | |
parser = argparse.ArgumentParser(description='Stylization and Seamless Video Dubbing') | |
parser.add_argument('--face', default='examples', type=str, help='') | |
parser.add_argument('--audio', default='examples', type=str, help='') | |
parser.add_argument('--source_dir', default='examples', type=str,help='TODO') | |
parser.add_argument('--result_dir', default='examples', type=str,help='TODO') | |
parser.add_argument('--backend', default='wav2lip', type=str,help='wav2lip or pcavs') | |
parser.add_argument('--result_tag', default='result', type=str,help='TODO') | |
parser.add_argument('--netR', default='unet_256', type=str,help='TODO') | |
parser.add_argument('--render_path', default='', type=str,help='TODO') | |
parser.add_argument('--ngf', default=16, type=int,help='TODO') | |
parser.add_argument('--fps', default=20, type=int,help='TODO') | |
parser.add_argument('--mask', default=100, type=int,help='TODO') | |
parser.add_argument('--mask_type', default='v3', type=str,help='TODO') | |
parser.add_argument('--image_size', default=256, type=int,help='TODO') | |
parser.add_argument('--input_nc', default=21, type=int,help='TODO') | |
parser.add_argument('--output_nc', default=3, type=int,help='TODO') | |
parser.add_argument('--renderbs', default=16, type=int,help='TODO') | |
parser.add_argument('--tframes', default=1, type=int,help='TODO') | |
parser.add_argument('--debug', action='store_true') | |
parser.add_argument('--enhance', action='store_true') | |
parser.add_argument('--phone', action='store_true') | |
#### faceformer | |
parser.add_argument("--model_name", type=str, default="VOCA") | |
parser.add_argument("--dataset", type=str, default="vocaset", help='vocaset or BIWI') | |
parser.add_argument("--feature_dim", type=int, default=64, help='64 for vocaset; 128 for BIWI') | |
parser.add_argument("--period", type=int, default=30, help='period in PPE - 30 for vocaset; 25 for BIWI') | |
parser.add_argument("--vertice_dim", type=int, default=5023*3, help='number of vertices - 5023*3 for vocaset; 23370*3 for BIWI') | |
parser.add_argument("--device", type=str, default="cuda") | |
parser.add_argument("--train_subjects", type=str, default="FaceTalk_170728_03272_TA ") | |
parser.add_argument("--test_subjects", type=str, default="FaceTalk_170809_00138_TA FaceTalk_170731_00024_TA") | |
parser.add_argument("--condition", type=str, default="FaceTalk_170904_00128_TA", help='select a conditioning subject from train_subjects') | |
parser.add_argument("--subject", type=str, default="FaceTalk_170731_00024_TA", help='select a subject from test_subjects or train_subjects') | |
parser.add_argument("--background_black", type=bool, default=True, help='whether to use black background') | |
parser.add_argument("--template_path", type=str, default="templates.pkl", help='path of the personalized templates') | |
parser.add_argument("--render_template_path", type=str, default="templates", help='path of the mesh in BIWI/FLAME topology') | |
opt = parser.parse_args() | |
opt.img_size = 96 | |
opt.static = True | |
opt.device = torch.device("cuda") | |
a2m = Audio2Mesh(opt) | |
print('link start!') | |
t = time.time() | |
# 02780 | |
a2m.run(opt.face, opt.audio, 0) | |
print(time.time() - t) |