# -*- coding: UTF-8 -*-
'''
@File   : inference.py
@Author : Chaolong Yang
@Date   : 2024/5/29 19:26
'''
import os
os.environ['HYDRA_FULL_ERROR'] = '1'

import glob
import time
import shutil
import uuid
import argparse
import datetime
import platform
import pdb

import cv2
import tyro
import numpy as np
from tqdm import tqdm
from rich.progress import track
from PIL import Image
import torch
import torch.nn.functional as F
from torch import nn
import imageio
import ffmpeg
import scipy
from pydub import AudioSegment
from pykalman import KalmanFilter
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from omegaconf import OmegaConf

from difpoint.croper import Croper
from difpoint.dataset_process import audio
# from difpoint.src.pipelines.faster_live_portrait_pipeline import FasterLivePortraitPipeline
from difpoint.src.live_portrait_pipeline import LivePortraitPipeline
from difpoint.src.config.argument_config import ArgumentConfig
from difpoint.src.config.inference_config import InferenceConfig
from difpoint.src.config.crop_config import CropConfig
from difpoint.src.utils.retargeting_utils import calc_eye_close_ratio, calc_lip_close_ratio
from difpoint.src.utils.camera import get_rotation_matrix
from difpoint.src.utils.video import images2video, concat_frames, get_fps, add_audio_to_video, has_audio_stream

FFMPEG = "ffmpeg"


def parse_audio_length(audio_length, sr, fps):
    bit_per_frames = sr / fps
    num_frames = int(audio_length / bit_per_frames)
    audio_length = int(num_frames * bit_per_frames)
    return audio_length, num_frames
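# Example (illustrative values, not from the repo): with sr=16000 and fps=25,
# bit_per_frames is 640.0, so a 64000-sample clip maps to 100 frames and is
# trimmed to exactly 100 * 640 = 64000 samples.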


def crop_pad_audio(wav, audio_length):
    if len(wav) > audio_length:
        wav = wav[:audio_length]
    elif len(wav) < audio_length:
        wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
    return wav


class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual
        self.use_act = use_act

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        if self.use_act:
            return self.act(out)
        else:
            return out


class AudioEncoder(nn.Module):
    def __init__(self, wav2lip_checkpoint, device):
        super(AudioEncoder, self).__init__()

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )

        #### load the pre-trained audio_encoder
        wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
        state_dict = self.audio_encoder.state_dict()
        for k, v in wav2lip_state_dict.items():
            if 'audio_encoder' in k:
                state_dict[k.replace('module.audio_encoder.', '')] = v
        self.audio_encoder.load_state_dict(state_dict)

    def forward(self, audio_sequences):
        # audio_sequences: (B, T, 1, 80, 16)
        B = audio_sequences.size(0)

        audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
        dim = audio_embedding.shape[1]
        audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))

        return audio_embedding.squeeze(-1).squeeze(-1)  # B seq_len+1 512
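
    # Shape note: each (1, 80, 16) mel window is reduced to a single 512-d vector
    # by the strided conv stack above, so an input of shape (B, T, 1, 80, 16)
    # comes back as (B, T, 512).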


def partial_fields(target_class, kwargs):
    return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)})


def dct2device(dct: dict, device):
    for key in dct:
        dct[key] = torch.tensor(dct[key]).to(device)
    return dct
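

# partial_fields keeps only the kwargs that the target (data)class actually
# declares, e.g. partial_fields(InferenceConfig, args.__dict__) below silently
# drops any CLI field that InferenceConfig does not define.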


def save_video_with_watermark(video, audio, save_path, watermark=False):
    temp_file = str(uuid.uuid4()) + '.mp4'
    cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file)
    os.system(cmd)
    shutil.move(temp_file, save_path)


class Inferencer(object):
    def __init__(self):
        st = time.time()
        print('#' * 25 + 'Start initialization' + '#' * 25)
        self.device = 'cuda'

        from difpoint.model import get_model
        self.point_diffusion = get_model()
        ckpt = torch.load('./downloaded_repo/ckpts/KDTalker.pth', weights_only=False)

        self.point_diffusion.load_state_dict(ckpt['model'])
        print('model', self.point_diffusion.children())
        self.point_diffusion.eval()
        self.point_diffusion.to(self.device)

        lm_croper_checkpoint = './downloaded_repo/ckpts/shape_predictor_68_face_landmarks.dat'
        self.croper = Croper(lm_croper_checkpoint)

        self.norm_info = dict(np.load(r'difpoint/datasets/norm_info_d6.5_c8.5_vox1_train.npz'))

        wav2lip_checkpoint = './downloaded_repo/ckpts/wav2lip.pth'
        self.wav2lip_model = AudioEncoder(wav2lip_checkpoint, 'cuda')
        self.wav2lip_model.cuda()
        self.wav2lip_model.eval()

        args = tyro.cli(ArgumentConfig)
        self.inf_cfg = partial_fields(InferenceConfig, args.__dict__)  # initialize InferenceConfig from the attributes of args
        self.crop_cfg = partial_fields(CropConfig, args.__dict__)      # initialize CropConfig from the attributes of args

        self.live_portrait_pipeline = LivePortraitPipeline(inference_cfg=self.inf_cfg, crop_cfg=self.crop_cfg)
        print('#' * 25 + f'End initialization, cost time {time.time() - st}' + '#' * 25)

    def _norm(self, data_dict):
        for k in data_dict.keys():
            if k in ['yaw', 'pitch', 'roll', 't', 'exp', 'scale', 'kp', 'c_lip', 'c_eye']:
                v = data_dict[k]
                data_dict[k] = (v - self.norm_info[k + '_mean']) / self.norm_info[k + '_std']
        return data_dict

    def _denorm(self, data_dict):
        for k in data_dict.keys():
            if k in ['yaw', 'pitch', 'roll', 't', 'exp', 'scale', 'kp', 'c_lip', 'c_eye']:
                v = data_dict[k]
                data_dict[k] = v * self.norm_info[k + '_std'] + self.norm_info[k + '_mean']
        return data_dict

    def output_to_dict(self, data):
        output = {}
        output['scale'] = data[:, 0]
        output['yaw'] = data[:, 1, None]
        output['pitch'] = data[:, 2, None]
        output['roll'] = data[:, 3, None]
        output['t'] = data[:, 4:7]
        output['exp'] = data[:, 7:]
        return output
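
    # Layout of the 70-d per-frame vector used by the diffusion model:
    # [scale(1), yaw(1), pitch(1), roll(1), t(3), exp(63)] = 70 values,
    # matching the concatenation built in generate_with_audio_img below.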

    def extract_mel_from_audio(self, audio_file_path):
        syncnet_mel_step_size = 16
        fps = 25
        wav = audio.load_wav(audio_file_path, 16000)
        wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
        wav = crop_pad_audio(wav, wav_length)
        orig_mel = audio.melspectrogram(wav).T
        spec = orig_mel.copy()
        indiv_mels = []

        for i in tqdm(range(num_frames), 'mel:'):
            start_frame_num = i - 2
            start_idx = int(80. * (start_frame_num / float(fps)))
            end_idx = start_idx + syncnet_mel_step_size
            seq = list(range(start_idx, end_idx))
            seq = [min(max(item, 0), orig_mel.shape[0] - 1) for item in seq]
            m = spec[seq, :]
            indiv_mels.append(m.T)
        indiv_mels = np.asarray(indiv_mels)  # T 80 16
        return indiv_mels
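
    # Each entry of indiv_mels is an (80, 16) mel window: 16 spectrogram steps
    # (about 0.2 s at 80 mel steps per second) starting two video frames before
    # frame i, with indices clamped at the clip boundaries.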

    def extract_wav2lip_from_audio(self, audio_file_path):
        asd_mel = self.extract_mel_from_audio(audio_file_path)
        asd_mel = torch.FloatTensor(asd_mel).cuda().unsqueeze(0).unsqueeze(2)
        with torch.no_grad():
            hidden = self.wav2lip_model(asd_mel)
        return hidden[0].cpu().detach().numpy()

    def headpose_pred_to_degree(self, pred):
        device = pred.device
        idx_tensor = [idx for idx in range(66)]
        idx_tensor = torch.FloatTensor(idx_tensor).to(device)
        pred = F.softmax(pred, dim=1)
        degree = torch.sum(pred * idx_tensor, 1) * 3 - 99
        return degree
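
    # Converts a 66-bin head-pose classification into a continuous angle via the
    # softmax expectation, with 3-degree bins offset by -99 (Hopenet-style
    # binning). Note this helper is not referenced elsewhere in this file.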

    def calc_combined_eye_ratio(self, c_d_eyes_i, c_s_eyes):
        c_s_eyes_tensor = torch.from_numpy(c_s_eyes).float().to(self.device)
        c_d_eyes_i_tensor = c_d_eyes_i[0].reshape(1, 1).to(self.device)
        # [c_s,eyes, c_d,eyes,i]
        combined_eye_ratio_tensor = torch.cat([c_s_eyes_tensor, c_d_eyes_i_tensor], dim=1)
        return combined_eye_ratio_tensor

    def calc_combined_lip_ratio(self, c_d_lip_i, c_s_lip):
        c_s_lip_tensor = torch.from_numpy(c_s_lip).float().to(self.device)
        c_d_lip_i_tensor = c_d_lip_i[0].to(self.device).reshape(1, 1)  # 1x1
        # [c_s,lip, c_d,lip,i]
        combined_lip_ratio_tensor = torch.cat([c_s_lip_tensor, c_d_lip_i_tensor], dim=1)  # 1x2
        return combined_lip_ratio_tensor

    # 2024.06.26
    def generate_with_audio_img(self, upload_audio_path, tts_audio_path, audio_type, image_path,
                                smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t,
                                save_path='./downloaded_repo/'):
        print(audio_type)
        if audio_type == 'upload':
            audio_path = upload_audio_path
        elif audio_type == 'tts':
            audio_path = tts_audio_path
        save_path = os.path.join(save_path, "output.mp4")

        image = [np.array(Image.open(image_path).convert('RGB'))]
        if image[0].shape[0] != 256 or image[0].shape[1] != 256:
            cropped_image, crop, quad = self.croper.crop(image, still=False, xsize=512)
            input_image = cv2.resize(cropped_image[0], (256, 256))
        else:
            input_image = image[0]

        I_s = torch.FloatTensor(input_image.transpose((2, 0, 1))).unsqueeze(0).cuda() / 255

        x_s_info = self.live_portrait_pipeline.live_portrait_wrapper.get_kp_info(I_s)
        x_c_s = x_s_info['kp'].reshape(1, 21, -1)
        R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
        f_s = self.live_portrait_pipeline.live_portrait_wrapper.extract_feature_3d(I_s)
        x_s = self.live_portrait_pipeline.live_portrait_wrapper.transform_keypoint(x_s_info)

        flag_lip_zero = self.inf_cfg.flag_lip_zero  # not overwrite

        ######## process driving info ########
        kp_info = {}
        for k in x_s_info.keys():
            kp_info[k] = x_s_info[k].cpu().numpy()
        # kp_info['c_lip'] = c_s_lip
        # kp_info['c_eye'] = c_s_eye
        kp_info = self._norm(kp_info)

        ori_kp = torch.cat([torch.zeros([1, 7]), torch.Tensor(kp_info['kp'])], -1).cuda()

        input_x = np.concatenate([kp_info[k] for k in ['scale', 'yaw', 'pitch', 'roll', 't']], 1)
        input_x = np.concatenate((input_x, kp_info['exp'].reshape(1, 63)), axis=1)
        input_x = np.expand_dims(input_x, -1)
        input_x = np.expand_dims(input_x, 0)
        input_x = np.concatenate([input_x, input_x, input_x], -1)

        aud_feat = self.extract_wav2lip_from_audio(audio_path)

        outputs = [input_x]

        st = time.time()
        print('#' * 25 + 'Start Inference' + '#' * 25)
        sample_frame = 64  # 32 aud_feat.shape[0]

        for i in range(0, aud_feat.shape[0] - 1, sample_frame):
            input_mel = torch.Tensor(aud_feat[i: i + sample_frame]).unsqueeze(0).cuda()
            kp0 = torch.Tensor(outputs[-1])[:, -1].cuda()
            pred_kp = self.point_diffusion.forward_sample(70, ref_kps=kp0, ori_kps=ori_kp, aud_feat=input_mel,
                                                          scheduler='ddim', num_inference_steps=50)
            outputs.append(pred_kp.cpu().numpy())
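        # The diffusion model is sampled in windows of up to 64 audio frames; the
        # last frame of the previous window (kp0) conditions the next one, which
        # keeps the generated motion continuous across window boundaries.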

        outputs = np.mean(np.concatenate(outputs, 1)[0], -1)[1:, ]
        output_dict = self.output_to_dict(outputs)
        output_dict = self._denorm(output_dict)

        num_frame = output_dict['yaw'].shape[0]
        x_d_info = {}
        for key in output_dict:
            x_d_info[key] = torch.tensor(output_dict[key]).cuda()

        # smooth
        def smooth(sequence, n_dim_state=1):
            kf = KalmanFilter(initial_state_mean=sequence[0],
                              transition_covariance=0.05 * np.eye(n_dim_state),    # small process noise
                              observation_covariance=0.001 * np.eye(n_dim_state))  # increase observation noise to reduce sensitivity
            state_means, _ = kf.smooth(sequence)
            return state_means

        # scale_data = x_d_info['scale'].cpu().numpy()
        yaw_data = x_d_info['yaw'].cpu().numpy()
        pitch_data = x_d_info['pitch'].cpu().numpy()
        roll_data = x_d_info['roll'].cpu().numpy()
        t_data = x_d_info['t'].cpu().numpy()
        exp_data = x_d_info['exp'].cpu().numpy()

        smoothed_pitch = smooth(pitch_data, n_dim_state=1) * smoothed_pitch
        smoothed_yaw = smooth(yaw_data, n_dim_state=1) * smoothed_yaw
        smoothed_roll = smooth(roll_data, n_dim_state=1) * smoothed_roll
        # smoothed_scale = smooth(scale_data, n_dim_state=1)
        smoothed_t = smooth(t_data, n_dim_state=3) * smoothed_t
        smoothed_exp = smooth(exp_data, n_dim_state=63)

        # x_d_info['scale'] = torch.Tensor(smoothed_scale).cuda()
        x_d_info['pitch'] = torch.Tensor(smoothed_pitch).cuda()
        x_d_info['yaw'] = torch.Tensor(smoothed_yaw).cuda()
        x_d_info['roll'] = torch.Tensor(smoothed_roll).cuda()
        x_d_info['t'] = torch.Tensor(smoothed_t).cuda()
        x_d_info['exp'] = torch.Tensor(smoothed_exp).cuda()
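        # The smoothed_pitch / smoothed_yaw / smoothed_roll / smoothed_t arguments
        # act as scaling factors on the Kalman-smoothed pose and translation, so
        # values below 1.0 damp the corresponding head motion.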

        template_dct = {'motion': [], 'c_d_eyes_lst': [], 'c_d_lip_lst': []}
        for i in track(range(num_frame), description='Making motion templates...', total=num_frame):
            # collect s_d, R_d, δ_d and t_d for inference
            x_d_i_info = x_d_info
            R_d_i = get_rotation_matrix(x_d_i_info['pitch'][i], x_d_i_info['yaw'][i], x_d_i_info['roll'][i])

            item_dct = {
                'scale': x_d_i_info['scale'][i].cpu().numpy().astype(np.float32),
                'R_d': R_d_i.cpu().numpy().astype(np.float32),
                'exp': x_d_i_info['exp'][i].reshape(1, 21, -1).cpu().numpy().astype(np.float32),
                't': x_d_i_info['t'][i].cpu().numpy().astype(np.float32),
            }

            template_dct['motion'].append(item_dct)
            # template_dct['c_d_eyes_lst'].append(x_d_i_info['c_eye'][i])
            # template_dct['c_d_lip_lst'].append(x_d_i_info['c_lip'][i])
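
        # template_dct follows LivePortrait's motion-template layout: one dict of
        # (scale, R_d, exp, t) per frame; the eye/lip ratio lists are left empty here.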

        I_p_lst = []
        R_d_0, x_d_0_info = None, None

        for i in track(range(num_frame), description='Animating...', total=num_frame):
            x_d_i_info = template_dct['motion'][i]
            for key in x_d_i_info:
                x_d_i_info[key] = torch.tensor(x_d_i_info[key]).cuda()
            R_d_i = x_d_i_info['R_d']

            if i == 0:
                R_d_0 = R_d_i
                x_d_0_info = x_d_i_info

            if self.inf_cfg.flag_relative_motion:
                R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
                delta_new = x_s_info['exp'].reshape(1, 21, -1) + (x_d_i_info['exp'] - x_d_0_info['exp'])
                scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
                t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
            else:
                R_new = R_d_i
                delta_new = x_d_i_info['exp']
                scale_new = x_s_info['scale']
                t_new = x_d_i_info['t']

            t_new[..., 2] = 0  # zero tz
            x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new
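            # Driving keypoints follow the LivePortrait transform
            # x_d = s * (x_c @ R + delta) + t, with tz zeroed above so the driving
            # motion adds no depth translation.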

            # Algorithm 1:
            if not self.inf_cfg.flag_stitching and not self.inf_cfg.flag_eye_retargeting and not self.inf_cfg.flag_lip_retargeting:
                # without stitching or retargeting
                if flag_lip_zero:
                    # NOTE: lip_delta_before_animation is not computed in this function;
                    # this branch assumes it has been prepared when flag_lip_zero is enabled.
                    x_d_i_new += lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                else:
                    pass
            elif self.inf_cfg.flag_stitching and not self.inf_cfg.flag_eye_retargeting and not self.inf_cfg.flag_lip_retargeting:
                # with stitching and without retargeting
                if flag_lip_zero:
                    x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new) + lip_delta_before_animation.reshape(-1, x_s.shape[1], 3)
                else:
                    x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new)
            else:
                eyes_delta, lip_delta = None, None

                if self.inf_cfg.flag_relative_motion:  # use x_s
                    x_d_i_new = x_s + \
                        (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                        (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)
                else:  # use x_d,i
                    x_d_i_new = x_d_i_new + \
                        (eyes_delta.reshape(-1, x_s.shape[1], 3) if eyes_delta is not None else 0) + \
                        (lip_delta.reshape(-1, x_s.shape[1], 3) if lip_delta is not None else 0)

                if self.inf_cfg.flag_stitching:
                    x_d_i_new = self.live_portrait_pipeline.live_portrait_wrapper.stitching(x_s, x_d_i_new)

            out = self.live_portrait_pipeline.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new)
            I_p_i = self.live_portrait_pipeline.live_portrait_wrapper.parse_output(out['out'])[0]
            I_p_lst.append(I_p_i)

        video_name = os.path.basename(save_path)
        video_save_dir = os.path.dirname(save_path)
        path = os.path.join(video_save_dir, video_name)

        imageio.mimsave(path, I_p_lst, fps=float(25))

        audio_name = audio_path.split('/')[-1]
        new_audio_path = os.path.join(video_save_dir, audio_name)
        start_time = 0
        # cog will not keep the .mp3 filename
        sound = AudioSegment.from_file(audio_path)
        end_time = start_time + num_frame * 1 / 25 * 1000
        word1 = sound.set_frame_rate(16000)
        word = word1[start_time:end_time]
        word.export(new_audio_path, format="wav")

        save_video_with_watermark(path, new_audio_path, save_path, watermark=False)
        print(f'The generated video is named {video_save_dir}/{video_name}')

        print('#' * 25 + f'End Inference, cost time {time.time() - st}' + '#' * 25)

        return save_path
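

# Minimal usage sketch (illustrative only; the audio/image paths and the 0.8
# smoothing factors below are assumptions, not values shipped with this file):
#
#   inferencer = Inferencer()
#   inferencer.generate_with_audio_img(
#       upload_audio_path='example.wav', tts_audio_path=None,
#       audio_type='upload', image_path='example.png',
#       smoothed_pitch=0.8, smoothed_yaw=0.8, smoothed_roll=0.8, smoothed_t=0.8,
#       save_path='./downloaded_repo/')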