# -*- coding: UTF-8 -*-
'''
@File :inference.py
@Author :Chaolong Yang
@Date :2024/5/29 19:26
'''
import os
import glob
import time
import shutil
import uuid
import argparse
import datetime
import platform
import subprocess
import pdb

import cv2
import numpy as np
import scipy
import imageio
import ffmpeg
import tyro
import torch
import torch.nn.functional as F
from torch import nn
from PIL import Image
from tqdm import tqdm
from rich.progress import track
from pydub import AudioSegment
from pykalman import KalmanFilter
from omegaconf import OmegaConf

import matplotlib
matplotlib.use('Agg')  # select a headless backend before pyplot is imported
import matplotlib.pyplot as plt

from difpoint.croper import Croper
from difpoint.dataset_process import audio
from difpoint.src.utils.crop import crop_image, parse_bbox_from_landmark, crop_image_by_bbox, paste_back, paste_back_pytorch
from difpoint.src.utils.utils import (resize_to_limit, prepare_paste_back, get_rotation_matrix,
                                      calc_lip_close_ratio, calc_eye_close_ratio,
                                      transform_keypoint, concat_feat)
from difpoint.src.utils import utils
from difpoint.src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
FFMPEG = "ffmpeg"
def parse_audio_length(audio_length, sr, fps):
bit_per_frames = sr / fps
num_frames = int(audio_length / bit_per_frames)
audio_length = int(num_frames * bit_per_frames)
return audio_length, num_frames
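# Example (illustrative): at sr=16000 and fps=25, bit_per_frames = 640 samples per video frame,
# so a 16480-sample wav gives num_frames = 25 and is trimmed/padded to 16000 samples by
# crop_pad_audio below, keeping the audio length aligned with the video frame count.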
def crop_pad_audio(wav, audio_length):
if len(wav) > audio_length:
wav = wav[:audio_length]
elif len(wav) < audio_length:
wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
return wav
class Conv2d(nn.Module):
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conv_block = nn.Sequential(
nn.Conv2d(cin, cout, kernel_size, stride, padding),
nn.BatchNorm2d(cout)
)
self.act = nn.ReLU()
self.residual = residual
self.use_act = use_act
def forward(self, x):
out = self.conv_block(x)
if self.residual:
out += x
if self.use_act:
return self.act(out)
else:
return out
class AudioEncoder(nn.Module):
def __init__(self, wav2lip_checkpoint, device):
super(AudioEncoder, self).__init__()
self.audio_encoder = nn.Sequential(
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
#### load the pre-trained audio_encoder
wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
state_dict = self.audio_encoder.state_dict()
for k,v in wav2lip_state_dict.items():
if 'audio_encoder' in k:
state_dict[k.replace('module.audio_encoder.', '')] = v
self.audio_encoder.load_state_dict(state_dict)
def forward(self, audio_sequences):
# audio_sequences = (B, T, 1, 80, 16)
B = audio_sequences.size(0)
audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
dim = audio_embedding.shape[1]
audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
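# Shape walkthrough (as used by extract_wav2lip_from_audio below): the input is (B, T, 1, 80, 16)
# mel chunks; the T chunks are folded into the batch, encoded to (B*T, 512, 1, 1), reshaped back
# to (B, T, 512, 1, 1), and returned as (B, T, 512).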
def partial_fields(target_class, kwargs):
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})
def dct2device(dct: dict, device):
for key in dct:
dct[key] = torch.tensor(dct[key]).to(device)
return dct
def save_video_with_watermark(video, audio, save_path, watermark=False):
    # 'watermark' is accepted for API compatibility but is not used here.
    temp_file = str(uuid.uuid4()) + '.mp4'
    cmd = r'%s -y -i "%s" -i "%s" -vcodec copy "%s"' % (FFMPEG, video, audio, temp_file)
    os.system(cmd)
    shutil.move(temp_file, save_path)
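# A minimal alternative sketch (not the author's code, assuming ffmpeg is on PATH): the same
# muxing step can be done with subprocess and a list of arguments, avoiding shell quoting issues:
#
#   subprocess.run([FFMPEG, "-y", "-i", video, "-i", audio, "-vcodec", "copy", temp_file],
#                  check=True)
#
# "-vcodec copy" copies the video stream untouched while the audio track is attached/re-encoded.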
class Inferencer(object):
def __init__(self):
st=time.time()
print('#'*25+'Start initialization'+'#'*25)
self.device = 'cuda'
from difpoint.model import get_model
self.point_diffusion = get_model()
ckpt = torch.load('./downloaded_repo/ckpts/KDTalker.pth', weights_only=False)
self.point_diffusion.load_state_dict(ckpt['model'])
print('model', self.point_diffusion.children())
self.point_diffusion.eval()
self.point_diffusion.to(self.device)
lm_croper_checkpoint = './downloaded_repo/ckpts/shape_predictor_68_face_landmarks.dat'
self.croper = Croper(lm_croper_checkpoint)
self.norm_info = dict(np.load(r'difpoint/datasets/norm_info_d6.5_c8.5_vox1_train.npz'))
wav2lip_checkpoint = './downloaded_repo/ckpts/wav2lip.pth'
self.wav2lip_model = AudioEncoder(wav2lip_checkpoint, 'cuda')
self.wav2lip_model.cuda()
self.wav2lip_model.eval()
# specify configs for inference
self.inf_cfg = OmegaConf.load("difpoint/configs/onnx_mp_infer.yaml")
self.inf_cfg.infer_params.flag_pasteback = False
self.live_portrait_pipeline = FasterLivePortraitPipeline(cfg=self.inf_cfg, is_animal=False)
#ret = self.live_portrait_pipeline.prepare_source(source_image)
print('#'*25+f'End initialization, cost time {time.time()-st}'+'#'*25)
def _norm(self, data_dict):
for k in data_dict.keys():
if k in ['yaw', 'pitch', 'roll', 't', 'scale', 'c_lip', 'c_eye']:
v=data_dict[k]
data_dict[k] = (v - self.norm_info[k+'_mean'])/self.norm_info[k+'_std']
elif k in ['exp', 'kp']:
v=data_dict[k]
data_dict[k] = (v - self.norm_info[k+'_mean'].reshape(1,21,3))/self.norm_info[k+'_std'].reshape(1,21,3)
return data_dict
def _denorm(self, data_dict):
for k in data_dict.keys():
if k in ['yaw', 'pitch', 'roll', 't', 'scale', 'c_lip', 'c_eye']:
v=data_dict[k]
data_dict[k] = v * self.norm_info[k+'_std'] + self.norm_info[k+'_mean']
elif k in ['exp', 'kp']:
v=data_dict[k]
data_dict[k] = v * self.norm_info[k+'_std'] + self.norm_info[k+'_mean']
return data_dict
def output_to_dict(self, data):
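        # Layout of the 70-dim motion vector produced by the diffusion model:
        # [0] scale, [1] yaw, [2] pitch, [3] roll, [4:7] translation t, [7:70] expression (21 x 3).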
output = {}
output['scale'] = data[:, 0]
output['yaw'] = data[:, 1, None]
output['pitch'] = data[:, 2, None]
output['roll'] = data[:, 3, None]
output['t'] = data[:, 4:7]
output['exp'] = data[:, 7:]
return output
def extract_mel_from_audio(self, audio_file_path):
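        # Per-frame mel windows (illustrative example): with fps=25 and 80 mel steps per second,
        # frame i uses the 16-step window starting at int(80 * (i - 2) / 25); for i=27 that is
        # mel columns 80..95, a ~0.2 s context starting two video frames before i, clamped at the edges.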
syncnet_mel_step_size = 16
fps = 25
wav = audio.load_wav(audio_file_path, 16000)
wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
wav = crop_pad_audio(wav, wav_length)
orig_mel = audio.melspectrogram(wav).T
spec = orig_mel.copy()
indiv_mels = []
for i in tqdm(range(num_frames), 'mel:'):
start_frame_num = i - 2
start_idx = int(80. * (start_frame_num / float(fps)))
end_idx = start_idx + syncnet_mel_step_size
seq = list(range(start_idx, end_idx))
seq = [min(max(item, 0), orig_mel.shape[0] - 1) for item in seq]
m = spec[seq, :]
indiv_mels.append(m.T)
indiv_mels = np.asarray(indiv_mels) # T 80 16
return indiv_mels
def extract_wav2lip_from_audio(self, audio_file_path):
asd_mel = self.extract_mel_from_audio(audio_file_path)
asd_mel = torch.FloatTensor(asd_mel).cuda().unsqueeze(0).unsqueeze(2)
with torch.no_grad():
hidden = self.wav2lip_model(asd_mel)
return hidden[0].cpu().detach().numpy()
def headpose_pred_to_degree(self, pred):
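        # The 66 logits are softmax-weighted bin indices; each bin spans 3 degrees, so the
        # result sum(p * idx) * 3 - 99 lies roughly in [-99, 96] degrees.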
device = pred.device
idx_tensor = [idx for idx in range(66)]
idx_tensor = torch.FloatTensor(idx_tensor).to(device)
        pred = F.softmax(pred, dim=1)  # softmax over the 66 angle bins
degree = torch.sum(pred * idx_tensor, 1) * 3 - 99
return degree
def calc_combined_eye_ratio(self, c_d_eyes_i, c_s_eyes):
c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
c_d_eyes_i_tensor = c_d_eyes_i[0].reshape(1, 1).to(self.device)
# [c_s,eyes, c_d,eyes,i]
combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
return combined_eye_ratio_tensor
def calc_combined_lip_ratio(self, c_d_lip_i, c_s_lip):
c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
c_d_lip_i_tensor = c_d_lip_i[0].to(self.device).reshape(1, 1) # 1x1
# [c_s,lip, c_d,lip,i]
combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1) # 1x2
return combined_lip_ratio_tensor
# 2024.06.26
@torch.no_grad()
def generate_with_audio_img(self, upload_audio_path, tts_audio_path, audio_type, image_path, smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t, save_path='results'):
        print(audio_type)
        if audio_type == 'upload':
            audio_path = upload_audio_path
        elif audio_type == 'tts':
            audio_path = tts_audio_path
        else:
            raise ValueError(f"Unknown audio_type {audio_type!r}; expected 'upload' or 'tts'")
save_path = os.path.join(save_path, "output.mp4")
image = [np.array(Image.open(image_path).convert('RGB'))]
if image[0].shape[0] != 256 or image[0].shape[1] != 256:
cropped_image, crop, quad = self.croper.crop(image, still=False, xsize=512)
input_image = cv2.resize(cropped_image[0], (256, 256))
else:
input_image = image[0]
I_s = torch.FloatTensor(input_image.transpose((2, 0, 1))).unsqueeze(0).cuda() / 255
pitch, yaw, roll, t, exp, scale, kp = self.live_portrait_pipeline.model_dict["motion_extractor"].predict(
I_s)
x_s_info = {
"pitch": pitch,
"yaw": yaw,
"roll": roll,
"t": t,
"exp": exp,
"scale": scale,
"kp": kp
}
x_c_s = kp.reshape(1, 21, -1)
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
f_s = self.live_portrait_pipeline.model_dict["app_feat_extractor"].predict(I_s)
x_s = transform_keypoint(pitch, yaw, roll, t, exp, scale, kp)
        flag_lip_zero = self.inf_cfg.infer_params.flag_normalize_lip
        if flag_lip_zero:
            # let the lip-open scalar be 0 at first
            c_d_lip_before_animation = [0.]
            # NOTE: this branch expects `combined_lip_ratio_tensor_before_animation` (the source
            # lip ratio combined with c_d_lip_before_animation) to have been computed beforehand;
            # it is not built in this script, so flag_normalize_lip should stay disabled in the config.
            lip_delta_before_animation = self.live_portrait_pipeline.model_dict['stitching_lip_retarget'].predict(
                concat_feat(x_s, combined_lip_ratio_tensor_before_animation))
######## process driving info ########
kp_info = {}
for k in x_s_info.keys():
kp_info[k] = x_s_info[k]
# kp_info['c_lip'] = c_s_lip
# kp_info['c_eye'] = c_s_eye
kp_info = self._norm(kp_info)
ori_kp = torch.cat([torch.zeros([1, 7]).to('cuda'), torch.Tensor(kp_info['kp'].reshape(1,63)).to('cuda')], -1).cuda()
input_x = np.concatenate([kp_info[k] for k in ['scale', 'yaw', 'pitch', 'roll', 't']], 1)
input_x = np.concatenate((input_x, kp_info['exp'].reshape(1, 63)), axis=1)
input_x = np.expand_dims(input_x, -1)
input_x = np.expand_dims(input_x, 0)
input_x = np.concatenate([input_x, input_x, input_x], -1)
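        # input_x is the 70-dim source motion vector [scale, yaw, pitch, roll, t(3), exp(63)],
        # shaped to (1, 1, 70, 1) and repeated 3x on the last axis (presumably to match the
        # 3-channel input the point diffusion model expects); it seeds the sampling loop below.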
aud_feat = self.extract_wav2lip_from_audio(audio_path)
outputs = [input_x]
st = time.time()
print('#' * 25 + 'Start Inference' + '#' * 25)
        sample_frame = 64  # audio frames per diffusion chunk (alternatives noted by the author: 32, aud_feat.shape[0])
for i in range(0, aud_feat.shape[0] - 1, sample_frame):
input_mel = torch.Tensor(aud_feat[i: i + sample_frame]).unsqueeze(0).cuda()
kp0 = torch.Tensor(outputs[-1])[:, -1].cuda()
pred_kp = self.point_diffusion.forward_sample(70, ref_kps=kp0, ori_kps=ori_kp, aud_feat=input_mel,
scheduler='ddim', num_inference_steps=50)
outputs.append(pred_kp.cpu().numpy())
outputs = np.mean(np.concatenate(outputs, 1)[0], -1)[1:, ]
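        # The chunked predictions were concatenated along time, the 3 repeated channels averaged
        # back into one 70-dim vector per frame, and the leading reference/source frame dropped.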
output_dict = self.output_to_dict(outputs)
output_dict = self._denorm(output_dict)
num_frame = output_dict['yaw'].shape[0]
x_d_info = {}
for key in output_dict:
x_d_info[key] = torch.tensor(output_dict[key]).cuda()
# smooth
        def smooth(sequence, n_dim_state=1):
            kf = KalmanFilter(initial_state_mean=sequence[0],
                              transition_covariance=0.05 * np.eye(n_dim_state),    # relatively small process noise
                              observation_covariance=0.001 * np.eye(n_dim_state))  # increase to reduce sensitivity to observations
            state_means, _ = kf.smooth(sequence)
            return state_means
# scale_data = x_d_info['scale'].cpu().numpy()
yaw_data = x_d_info['yaw'].cpu().numpy()
pitch_data = x_d_info['pitch'].cpu().numpy()
roll_data = x_d_info['roll'].cpu().numpy()
t_data = x_d_info['t'].cpu().numpy()
exp_data = x_d_info['exp'].cpu().numpy()
smoothed_pitch = smooth(pitch_data, n_dim_state=1) * smoothed_pitch
smoothed_yaw = smooth(yaw_data, n_dim_state=1) * smoothed_yaw
smoothed_roll = smooth(roll_data, n_dim_state=1) * smoothed_roll
# smoothed_scale = smooth(scale_data, n_dim_state=1)
smoothed_t = smooth(t_data, n_dim_state=3) * smoothed_t
smoothed_exp = smooth(exp_data, n_dim_state=63)
# x_d_info['scale'] = torch.Tensor(smoothed_scale).cuda()
x_d_info['pitch'] = torch.Tensor(smoothed_pitch).cuda()
x_d_info['yaw'] = torch.Tensor(smoothed_yaw).cuda()
x_d_info['roll'] = torch.Tensor(smoothed_roll).cuda()
x_d_info['t'] = torch.Tensor(smoothed_t).cuda()
x_d_info['exp'] = torch.Tensor(smoothed_exp).cuda()
template_dct = {'motion': [], 'c_d_eyes_lst': [], 'c_d_lip_lst': []}
for i in track(range(num_frame), description='Making motion templates...', total=num_frame):
# collect s_d, R_d, δ_d and t_d for inference
x_d_i_info = x_d_info
R_d_i = get_rotation_matrix(x_d_i_info['pitch'][i], x_d_i_info['yaw'][i], x_d_i_info['roll'][i])
item_dct = {
'scale': x_d_i_info['scale'][i].cpu().numpy().astype(np.float32),
'R_d': R_d_i.astype(np.float32),
'exp': x_d_i_info['exp'][i].reshape(1, 21, -1).cpu().numpy().astype(np.float32),
't': x_d_i_info['t'][i].cpu().numpy().astype(np.float32),
}
template_dct['motion'].append(item_dct)
# template_dct['c_d_eyes_lst'].append(x_d_i_info['c_eye'][i])
# template_dct['c_d_lip_lst'].append(x_d_i_info['c_lip'][i])
I_p_lst = []
R_d_0, x_d_0_info = None, None
for i in track(range(num_frame), description='Animating...', total=num_frame):
x_d_i_info = template_dct['motion'][i]
for key in x_d_i_info:
x_d_i_info[key] = torch.tensor(x_d_i_info[key]).cuda()
for key in x_s_info:
x_s_info[key] = torch.tensor(x_s_info[key]).cuda()
R_d_i = x_d_i_info['R_d']
if i == 0:
R_d_0 = R_d_i
x_d_0_info = x_d_i_info
if self.inf_cfg.infer_params.flag_relative_motion:
R_new = (R_d_i.cpu().numpy() @ R_d_0.permute(0, 2, 1).cpu().numpy()) @ R_s
delta_new = x_s_info['exp'].reshape(1, 21, -1) + (x_d_i_info['exp'] - x_d_0_info['exp'])
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
else:
R_new = R_d_i
delta_new = x_d_i_info['exp']
scale_new = x_s_info['scale']
t_new = x_d_i_info['t']
t_new[..., 2] = 0 # zero tz
x_c_s = torch.tensor(x_c_s, dtype=torch.float32).cuda()
R_new = torch.tensor(R_new, dtype=torch.float32).cuda()
delta_new = torch.tensor(delta_new, dtype=torch.float32).cuda()
t_new = torch.tensor(t_new, dtype=torch.float32).cuda()
scale_new = torch.tensor(scale_new, dtype=torch.float32).cuda()
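            # Driving keypoints follow the implicit-keypoint transform used by LivePortrait:
            # x_d = s * (x_c @ R + delta) + t, i.e. canonical keypoints rotated, offset by the
            # expression deltas, scaled, then translated.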
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
x_d_i_new = x_d_i_new.cpu().numpy()
# Algorithm 1:
if not self.inf_cfg.infer_params.flag_stitching and not self.inf_cfg.infer_params.flag_eye_retargeting and not self.inf_cfg.infer_params.flag_lip_retargeting:
# without stitching or retargeting
if flag_lip_zero:
x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
else:
pass
elif self.inf_cfg.infer_params.flag_stitching and not self.inf_cfg.infer_params.flag_eye_retargeting and not self.inf_cfg.infer_params.flag_lip_retargeting:
# with stitching and without retargeting
if flag_lip_zero:
x_d_i_new = self.live_portrait_pipeline.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(
-1, x_s.shape[1], 3)
else:
x_d_i_new = self.live_portrait_pipeline.stitching(x_s, x_d_i_new)
else:
eyes_delta, lip_delta = None, None
if self.inf_cfg.infer_params.flag_eye_retargeting:
c_d_eyes_i = template_dct['c_d_eyes_lst'][i]
combined_eye_ratio_tensor = self.calc_combined_eye_ratio(c_d_eyes_i, c_s_eye)
# ∆_eyes,i = R_eyes(x_s; c_s,eyes, c_d,eyes,i)
eyes_delta = self.live_portrait_pipeline.retarget_eye(x_s, combined_eye_ratio_tensor)
if self.inf_cfg.infer_params.flag_lip_retargeting:
c_d_lip_i = template_dct['c_d_lip_lst'][i]
combined_lip_ratio_tensor = self.calc_combined_lip_ratio(c_d_lip_i, c_s_lip)
# ∆_lip,i = R_lip(x_s; c_s,lip, c_d,lip,i)
lip_delta = self.live_portrait_pipeline.retarget_lip(x_s, combined_lip_ratio_tensor)
if self.inf_cfg.infer_params.flag_relative_motion: # use x_s
x_d_i_new = x_s + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
else: # use x_d,i
x_d_i_new = x_d_i_new + \
(eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
(lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
if self.inf_cfg.infer_params.flag_stitching:
x_d_i_new = self.live_portrait_pipeline.stitching(x_s, x_d_i_new)
out = self.live_portrait_pipeline.model_dict["warping_spade"].predict(f_s, x_s, x_d_i_new).cpu().numpy().astype(np.uint8)
I_p_lst.append(out)
video_name = os.path.basename(save_path)
video_save_dir = os.path.dirname(save_path)
path = os.path.join(video_save_dir, video_name)
imageio.mimsave(path, I_p_lst, fps=float(25))
audio_name = audio_path.split('/')[-1]
new_audio_path = os.path.join(video_save_dir, audio_name)
start_time = 0
# cog will not keep the .mp3 filename
sound = AudioSegment.from_file(audio_path)
end_time = start_time + num_frame * 1 / 25 * 1000
word1 = sound.set_frame_rate(16000)
word = word1[start_time:end_time]
word.export(new_audio_path, format="wav")
save_video_with_watermark(path, new_audio_path, save_path, watermark=False)
print(f'The generated video is named {video_save_dir}/{video_name}')
print('#' * 25 + f'End Inference, cost time {time.time() - st}' + '#' * 25)
return save_path
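# End-to-end flow of generate_with_audio_img: crop/resize the source image, extract its implicit
# 3D keypoints and appearance features, turn the driving audio into wav2lip embeddings, sample
# per-frame motion with the point diffusion model, Kalman-smooth the pose/expression, warp the
# source features frame by frame, and finally mux the rendered frames with the audio track.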
if __name__ == "__main__":
    Infer = Inferencer()
    # audio_type must be 'upload' or 'tts'; the wav path here is passed in the tts slot.
    Infer.generate_with_audio_img(None, 'difpoint/assets/test/test.wav', 'tts',
                                  'difpoint/assets/test/test2.jpg', 0.8, 0.8, 0.8, 0.8)