|
|
|
import os |
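# Surface full Hydra error traces and restrict the process to the first GPU.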
|
os.environ['HYDRA_FULL_ERROR'] = '1'
|
os.environ['CUDA_VISIBLE_DEVICES'] = '0' |
|
|
|
from huggingface_hub import snapshot_download |
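
# Download the KDTalker model files from the Hugging Face Hub into the current
# working directory; the relative paths used below ('ckpts/...',
# 'dataset_process/...') are expected to resolve against this download.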
|
|
|
|
|
snapshot_download( |
|
repo_id = "ChaolongYang/KDTalker", |
|
local_dir = "./" |
|
) |
|
|
|
import argparse |
|
import shutil |
|
import uuid |
|
import numpy as np |
|
from tqdm import tqdm |
|
import cv2 |
|
from rich.progress import track |
|
import tyro |
|
|
|
import gradio as gr |
|
from PIL import Image |
|
import time |
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
import imageio |
|
from pydub import AudioSegment |
|
from pykalman import KalmanFilter |
|
|
|
|
|
from src.config.argument_config import ArgumentConfig |
|
from src.config.inference_config import InferenceConfig |
|
from src.config.crop_config import CropConfig |
|
from src.live_portrait_pipeline import LivePortraitPipeline |
|
from src.utils.camera import get_rotation_matrix |
|
from dataset_process import audio |
|
|
|
from dataset_process.croper import Croper |
|
|
|
import spaces |
|
|
|
def parse_audio_length(audio_length, sr, fps): |
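    """Clip an audio length (in samples) to a whole number of video frames.

    `sr / fps` is the number of audio samples per video frame; returns the clipped
    sample count and the corresponding number of frames.
    """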
|
bit_per_frames = sr / fps |
|
num_frames = int(audio_length / bit_per_frames) |
|
audio_length = int(num_frames * bit_per_frames) |
|
return audio_length, num_frames |
|
|
|
def crop_pad_audio(wav, audio_length): |
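    """Crop or zero-pad `wav` so that it is exactly `audio_length` samples long."""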
|
if len(wav) > audio_length: |
|
wav = wav[:audio_length] |
|
elif len(wav) < audio_length: |
|
wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0) |
|
return wav |
|
|
|
class Conv2d(nn.Module): |
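    """Conv2d + BatchNorm block with an optional residual connection and ReLU,
    matching the building block used by the Wav2Lip audio encoder."""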
|
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act=True, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.conv_block = nn.Sequential( |
|
nn.Conv2d(cin, cout, kernel_size, stride, padding), |
|
nn.BatchNorm2d(cout) |
|
) |
|
self.act = nn.ReLU() |
|
self.residual = residual |
|
self.use_act = use_act |
|
|
|
def forward(self, x): |
|
out = self.conv_block(x) |
|
if self.residual: |
|
out += x |
|
|
|
if self.use_act: |
|
return self.act(out) |
|
else: |
|
return out |
|
|
|
class AudioEncoder(nn.Module): |
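    """Wav2Lip audio encoder.

    The convolutional stack mirrors Wav2Lip's audio encoder so the pretrained
    'audio_encoder.*' weights from a Wav2Lip checkpoint can be loaded directly;
    it maps per-frame (1, 80, 16) mel windows to 512-d embeddings.
    """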
|
def __init__(self, wav2lip_checkpoint, device): |
|
super(AudioEncoder, self).__init__() |
|
|
|
self.audio_encoder = nn.Sequential( |
|
Conv2d(1, 32, kernel_size=3, stride=1, padding=1), |
|
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), |
|
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), |
|
|
|
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), |
|
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), |
|
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), |
|
|
|
Conv2d(64, 128, kernel_size=3, stride=3, padding=1), |
|
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), |
|
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), |
|
|
|
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), |
|
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), |
|
|
|
Conv2d(256, 512, kernel_size=3, stride=1, padding=0), |
|
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) |
|
|
|
|
|
wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict'] |
|
state_dict = self.audio_encoder.state_dict() |
|
|
|
        for k, v in wav2lip_state_dict.items():
|
if 'audio_encoder' in k: |
|
state_dict[k.replace('module.audio_encoder.', '')] = v |
|
self.audio_encoder.load_state_dict(state_dict) |
|
|
|
def forward(self, audio_sequences): |
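        """Encode mel windows of shape (B, T, 1, 80, 16) into (B, T, 512) embeddings."""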
|
B = audio_sequences.size(0) |
|
|
|
audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) |
|
|
|
audio_embedding = self.audio_encoder(audio_sequences) |
|
dim = audio_embedding.shape[1] |
|
audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1)) |
|
|
|
return audio_embedding.squeeze(-1).squeeze(-1) |
|
|
|
def partial_fields(target_class, kwargs): |
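    """Instantiate `target_class` using only the entries of `kwargs` that it defines."""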
|
return target_class(**{k: v for k, v in kwargs.items() if hasattr(target_class, k)}) |
|
|
|
def dct2device(dct: dict, device): |
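    """Convert every value in `dct` to a torch tensor on `device` (in place)."""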
|
for key in dct: |
|
dct[key] = torch.tensor(dct[key]).to(device) |
|
return dct |
|
|
|
def save_video_with_watermark(video, audio, save_path): |
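    """Mux the audio track onto the video with ffmpeg (stream copy) and move the
    result to `save_path`. Despite its name, this does not add a watermark."""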
|
temp_file = str(uuid.uuid4())+'.mp4' |
|
cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file) |
|
os.system(cmd) |
|
shutil.move(temp_file, save_path) |
|
|
|
class Inferencer(object): |
|
def __init__(self): |
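        """Load the KDTalker keypoint diffusion model, the 68-landmark face cropper,
        the motion normalization statistics, the Wav2Lip audio encoder, and the
        LivePortrait pipeline configured via tyro-parsed defaults."""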
|
        st = time.time()
|
print('#'*25+'Start initialization'+'#'*25) |
|
self.device = 'cuda' |
|
|
|
from model import get_model |
|
self.point_diffusion = get_model() |
|
ckpt = torch.load('ckpts/KDTalker.pth') |
|
|
|
self.point_diffusion.load_state_dict(ckpt['model']) |
|
self.point_diffusion.eval() |
|
self.point_diffusion.to(self.device) |
|
|
|
lm_croper_checkpoint = 'ckpts/shape_predictor_68_face_landmarks.dat' |
|
self.croper = Croper(lm_croper_checkpoint) |
|
|
|
self.norm_info = dict(np.load('dataset_process/norm.npz')) |
|
|
|
wav2lip_checkpoint = 'ckpts/wav2lip.pth' |
|
self.wav2lip_model = AudioEncoder(wav2lip_checkpoint, 'cuda') |
|
self.wav2lip_model.cuda() |
|
self.wav2lip_model.eval() |
|
|
|
|
|
tyro.extras.set_accent_color("bright_cyan") |
|
args = tyro.cli(ArgumentConfig) |
|
|
|
|
|
self.inf_cfg = partial_fields(InferenceConfig, args.__dict__) |
|
self.crop_cfg = partial_fields(CropConfig, args.__dict__) |
|
|
|
self.live_portrait_pipeline = LivePortraitPipeline(inference_cfg=self.inf_cfg, crop_cfg=self.crop_cfg) |
|
|
|
def _norm(self, data_dict): |
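        """Z-score normalize the motion parameters using the precomputed dataset statistics."""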
|
for k in data_dict.keys(): |
|
            if k in ['yaw', 'pitch', 'roll', 't', 'exp', 'scale', 'kp']:
                v = data_dict[k]
                data_dict[k] = (v - self.norm_info[k + '_mean']) / self.norm_info[k + '_std']
|
return data_dict |
|
|
|
def _denorm(self, data_dict): |
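        """Invert `_norm`: map normalized motion parameters back to their original scale."""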
|
for k in data_dict.keys(): |
|
if k in ['yaw', 'pitch', 'roll', 't', 'exp', 'scale', 'kp']: |
|
                v = data_dict[k]
                data_dict[k] = v * self.norm_info[k + '_std'] + self.norm_info[k + '_mean']
|
return data_dict |
|
|
|
def output_to_dict(self, data): |
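        """Split a (num_frames, 70) motion array into named components:
        scale (col 0), yaw/pitch/roll (cols 1-3), translation (cols 4-6) and
        expression (the remaining 63 columns)."""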
|
output = {} |
|
output['scale'] = data[:, 0] |
|
output['yaw'] = data[:, 1, None] |
|
output['pitch'] = data[:, 2, None] |
|
output['roll'] = data[:, 3, None] |
|
output['t'] = data[:, 4:7] |
|
output['exp'] = data[:, 7:] |
|
return output |
|
|
|
def extract_mel_from_audio(self, audio_file_path): |
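        """Compute a 16 kHz mel spectrogram and slice it into one 16-step window per
        25 fps video frame (SyncNet style), returning an array of shape
        (num_frames, 80, 16)."""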
|
syncnet_mel_step_size = 16 |
|
fps = 25 |
|
wav = audio.load_wav(audio_file_path, 16000) |
|
wav_length, num_frames = parse_audio_length(len(wav), 16000, 25) |
|
wav = crop_pad_audio(wav, wav_length) |
|
orig_mel = audio.melspectrogram(wav).T |
|
spec = orig_mel.copy() |
|
indiv_mels = [] |
|
|
|
for i in tqdm(range(num_frames), 'mel:'): |
|
start_frame_num = i - 2 |
|
start_idx = int(80. * (start_frame_num / float(fps))) |
|
end_idx = start_idx + syncnet_mel_step_size |
|
seq = list(range(start_idx, end_idx)) |
|
seq = [min(max(item, 0), orig_mel.shape[0] - 1) for item in seq] |
|
m = spec[seq, :] |
|
indiv_mels.append(m.T) |
|
indiv_mels = np.asarray(indiv_mels) |
|
return indiv_mels |
|
|
|
def extract_wav2lip_from_audio(self, audio_file_path): |
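        """Encode the per-frame mel windows with the Wav2Lip audio encoder and return
        (num_frames, 512) audio embeddings."""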
|
asd_mel = self.extract_mel_from_audio(audio_file_path) |
|
asd_mel = torch.FloatTensor(asd_mel).cuda().unsqueeze(0).unsqueeze(2) |
|
with torch.no_grad(): |
|
hidden = self.wav2lip_model(asd_mel) |
|
return hidden[0].cpu().detach().numpy() |
|
|
|
def headpose_pred_to_degree(self, pred): |
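        """Convert 66-bin head-pose logits to a continuous angle in degrees
        (3-degree bins offset by -99). Not used in this script's inference path."""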
|
device = pred.device |
|
idx_tensor = [idx for idx in range(66)] |
|
idx_tensor = torch.FloatTensor(idx_tensor).to(device) |
|
        pred = F.softmax(pred, dim=1)
|
degree = torch.sum(pred * idx_tensor, 1) * 3 - 99 |
|
return degree |
|
|
|
@torch.no_grad() |
|
def generate_with_audio_img(self, image_path, audio_path, save_path): |
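        """End-to-end inference: crop and resize the source portrait, extract its
        LivePortrait keypoints, sample audio-driven motion with the keypoint
        diffusion model, Kalman-smooth the pose/expression tracks, render each
        frame with LivePortrait, and mux the driving audio into the saved video."""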
|
image = np.array(Image.open(image_path).convert('RGB')) |
|
cropped_image, crop, quad = self.croper.crop([image], still=False, xsize=512) |
|
input_image = cv2.resize(cropped_image[0], (256, 256)) |
|
|
|
I_s = torch.FloatTensor(input_image.transpose((2, 0, 1))).unsqueeze(0).cuda() / 255 |
|
|
|
x_s_info = self.live_portrait_pipeline.live_portrait_wrapper.get_kp_info(I_s) |
|
x_c_s = x_s_info['kp'].reshape(1, 21, -1) |
|
R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll']) |
|
f_s = self.live_portrait_pipeline.live_portrait_wrapper.extract_feature_3d(I_s) |
|
x_s = self.live_portrait_pipeline.live_portrait_wrapper.transform_keypoint(x_s_info) |
|
|
|
|
|
kp_info = {} |
|
for k in x_s_info.keys(): |
|
kp_info[k] = x_s_info[k].cpu().numpy() |
|
|
|
kp_info = self._norm(kp_info) |
|
|
|
ori_kp = torch.cat([torch.zeros([1, 7]), torch.Tensor(kp_info['kp'])], -1).cuda() |
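        # ori_kp: (1, 70) conditioning vector - seven zero slots for the
        # scale/rotation/translation channels followed by the 63 normalized
        # canonical keypoint coordinates.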
|
|
|
input_x = np.concatenate([kp_info[k] for k in ['scale', 'yaw', 'pitch', 'roll', 't', 'exp']], 1) |
|
input_x = np.expand_dims(input_x, -1) |
|
input_x = np.expand_dims(input_x, 0) |
|
input_x = np.concatenate([input_x, input_x, input_x], -1) |
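        # input_x: (1, 1, 70, 3) - the source motion vector replicated three times
        # along a trailing axis. It seeds the autoregressive sampling loop below,
        # and the three tracks are averaged back to one after sampling.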
|
|
|
aud_feat = self.extract_wav2lip_from_audio(audio_path) |
|
|
|
sample_frame = 64 |
|
padding_size = (sample_frame - aud_feat.shape[0] % sample_frame) % sample_frame |
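        # Pad the audio features so the frame count is a multiple of the 64-frame
        # window used when sampling the diffusion model.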
|
|
|
        if padding_size > 0:
            aud_feat = np.concatenate((aud_feat, aud_feat[:padding_size, :]), axis=0)
|
|
|
outputs = [input_x] |
|
|
|
for i in range(0, aud_feat.shape[0] - 1, sample_frame): |
|
input_mel = torch.Tensor(aud_feat[i: i + sample_frame]).unsqueeze(0).cuda() |
|
kp0 = torch.Tensor(outputs[-1])[:, -1].cuda() |
|
pred_kp = self.point_diffusion.forward_sample(70, ref_kps=kp0, ori_kps=ori_kp, aud_feat=input_mel, |
|
scheduler='ddim', num_inference_steps=50) |
|
outputs.append(pred_kp.cpu().numpy()) |
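
        # Concatenate the sampled windows along the frame axis, drop the seed frame
        # and the padded tail, then average the three replicated tracks to a single
        # (num_frames, 70) motion sequence.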
|
|
|
outputs = np.mean(np.concatenate(outputs, 1)[0, 1:aud_feat.shape[0] - padding_size + 1], -1) |
|
output_dict = self.output_to_dict(outputs) |
|
output_dict = self._denorm(output_dict) |
|
|
|
num_frame = output_dict['yaw'].shape[0] |
|
x_d_info = {} |
|
for key in output_dict: |
|
x_d_info[key] = torch.tensor(output_dict[key]).cuda() |
|
|
|
|
|
def smooth(sequence, n_dim_state=1): |
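            """Kalman-smooth a (num_frames, n_dim_state) trajectory to reduce temporal jitter."""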
|
kf = KalmanFilter(initial_state_mean=sequence[0], |
|
transition_covariance=0.05 * np.eye(n_dim_state), |
|
observation_covariance=0.001 * np.eye(n_dim_state)) |
|
state_means, _ = kf.smooth(sequence) |
|
return state_means |
|
|
|
yaw_data = x_d_info['yaw'].cpu().numpy() |
|
pitch_data = x_d_info['pitch'].cpu().numpy() |
|
roll_data = x_d_info['roll'].cpu().numpy() |
|
t_data = x_d_info['t'].cpu().numpy() |
|
exp_data = x_d_info['exp'].cpu().numpy() |
|
|
|
smoothed_pitch = smooth(pitch_data, n_dim_state=1) |
|
smoothed_yaw = smooth(yaw_data, n_dim_state=1) |
|
smoothed_roll = smooth(roll_data, n_dim_state=1) |
|
smoothed_t = smooth(t_data, n_dim_state=3) |
|
smoothed_exp = smooth(exp_data, n_dim_state=63) |
|
|
|
x_d_info['pitch'] = torch.Tensor(smoothed_pitch).cuda() |
|
x_d_info['yaw'] = torch.Tensor(smoothed_yaw).cuda() |
|
x_d_info['roll'] = torch.Tensor(smoothed_roll).cuda() |
|
x_d_info['t'] = torch.Tensor(smoothed_t).cuda() |
|
x_d_info['exp'] = torch.Tensor(smoothed_exp).cuda() |
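
        # Pack the smoothed motion into per-frame templates (scale, rotation matrix,
        # expression, translation) in the format the LivePortrait renderer expects.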
|
|
|
template_dct = {'motion': [], 'c_d_eyes_lst': [], 'c_d_lip_lst': []} |
|
for i in track(range(num_frame), description='Making motion templates...', total=num_frame): |
|
x_d_i_info = x_d_info |
|
R_d_i = get_rotation_matrix(x_d_i_info['pitch'][i], x_d_i_info['yaw'][i], x_d_i_info['roll'][i]) |
|
|
|
item_dct = { |
|
'scale': x_d_i_info['scale'][i].cpu().numpy().astype(np.float32), |
|
'R_d': R_d_i.cpu().numpy().astype(np.float32), |
|
'exp': x_d_i_info['exp'][i].reshape(1, 21, -1).cpu().numpy().astype(np.float32), |
|
't': x_d_i_info['t'][i].cpu().numpy().astype(np.float32), |
|
} |
|
|
|
template_dct['motion'].append(item_dct) |
|
|
|
I_p_lst = [] |
|
R_d_0, x_d_0_info = None, None |
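
        # Re-target the driving motion onto the source keypoints. When
        # flag_relative_motion is set, frame 0 of the driving sequence serves as a
        # reference so only relative pose/expression changes are applied.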
|
|
|
for i in track(range(num_frame), description='🚀Animating...', total=num_frame): |
|
x_d_i_info = template_dct['motion'][i] |
|
for key in x_d_i_info: |
|
x_d_i_info[key] = torch.tensor(x_d_i_info[key]).cuda() |
|
R_d_i = x_d_i_info['R_d'] |
|
|
|
if i == 0: |
|
R_d_0 = R_d_i |
|
x_d_0_info = x_d_i_info |
|
|
|
if self.inf_cfg.flag_relative_motion: |
|
R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s |
|
delta_new = x_s_info['exp'].reshape(1, 21, -1) + (x_d_i_info['exp'] - x_d_0_info['exp']) |
|
scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale']) |
|
t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t']) |
|
else: |
|
R_new = R_d_i |
|
delta_new = x_d_i_info['exp'] |
|
scale_new = x_s_info['scale'] |
|
t_new = x_d_i_info['t'] |
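
            # Keep the source depth (zero the z-translation), then compose the driven
            # keypoints: x_d = scale * (x_c_s @ R + exp) + t.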
|
|
|
t_new[..., 2].fill_(0) |
|
x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new |
|
|
|
out = self.live_portrait_pipeline.live_portrait_wrapper.warp_decode(f_s, x_s, x_d_i_new) |
|
I_p_i = self.live_portrait_pipeline.live_portrait_wrapper.parse_output(out['out'])[0] |
|
I_p_lst.append(I_p_i) |
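
        # Write the rendered frames to a temporary 25 fps video, cut the audio to the
        # same duration at 16 kHz, mux the two with ffmpeg, then remove the
        # intermediate files.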
|
|
|
video_name = save_path.split('/')[-1] |
|
video_save_dir = os.path.dirname(save_path) |
|
path = os.path.join(video_save_dir, 'temp_' + video_name) |
|
|
|
imageio.mimsave(path, I_p_lst, fps=float(25)) |
|
|
|
audio_name = audio_path.split('/')[-1] |
|
new_audio_path = os.path.join(video_save_dir, audio_name) |
|
start_time = 0 |
|
sound = AudioSegment.from_file(audio_path) |
|
end_time = start_time + num_frame * 1 / 25 * 1000 |
|
word1 = sound.set_frame_rate(16000) |
|
word = word1[start_time:end_time] |
|
word.export(new_audio_path, format="wav") |
|
|
|
save_video_with_watermark(path, new_audio_path, save_path) |
|
        print(f'The generated video is saved to {video_save_dir}/{video_name}')
|
|
|
os.remove(path) |
|
os.remove(new_audio_path) |
|
|
|
@spaces.GPU() |
|
def gradio_infer(source_image, driven_audio): |
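    """Gradio callback: generate a talking-head video for the uploaded portrait and
    audio inside a fresh temporary directory and return its path. The full model
    stack is re-initialized on every request."""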
|
|
|
import tempfile |
|
temp_dir = tempfile.mkdtemp() |
|
output_path = f"{temp_dir}/output.mp4" |
|
|
|
Infer = Inferencer() |
|
Infer.generate_with_audio_img(source_image, driven_audio, output_path) |
|
|
|
return output_path |
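

# Gradio UI: a source portrait and a driving audio clip in, the generated
# talking-head video out.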
|
|
|
with gr.Blocks() as demo: |
|
with gr.Column(): |
|
gr.Markdown("# KDTalker") |
|
gr.Markdown("Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait") |
|
gr.HTML(""" |
|
<div style="display:flex;column-gap:4px;"> |
|
<a href="https://github.com/chaolongy/KDTalker"> |
|
<img src='https://img.shields.io/badge/GitHub-Repo-blue'> |
|
</a> |
|
<a href="https://arxiv.org/abs/2503.12963"> |
|
<img src='https://img.shields.io/badge/ArXiv-Paper-red'> |
|
</a> |
|
<a href="https://huggingface.co/spaces/fffiloni/KDTalker?duplicate=true"> |
|
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space"> |
|
</a> |
|
</div> |
|
""") |
|
|
|
with gr.Row(): |
|
|
|
with gr.Column(): |
|
source_image = gr.Image(label="Source Image", type="filepath") |
|
driven_audio = gr.Audio(label="Driven Audio", type="filepath") |
|
submit_btn = gr.Button("Submit") |
|
|
|
gr.Examples( |
|
examples = [ |
|
["example/source_image/WDA_BenCardin1_000.png", "example/audio_driven/WDA_BenCardin1_000.wav"], |
|
|
|
], |
|
inputs = [source_image, driven_audio], |
|
cache_examples = False |
|
) |
|
|
|
with gr.Column(): |
|
output_video = gr.Video(label="Output Video") |
|
|
|
submit_btn.click( |
|
fn = gradio_infer, |
|
inputs = [source_image, driven_audio], |
|
outputs = [output_video] |
|
) |
|
|
|
demo.launch() |
|
|