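# KDTalker Hugging Face Space (ZeroGPU): a Gradio app that animates a single source
# portrait with driving audio (uploaded, recorded, or synthesised from text with Coqui TTS)
# and returns a talking-head video.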
import os, sys
import gradio as gr
from difpoint.inference import Inferencer
from TTS.api import TTS
import torch
import time
from flask import send_from_directory
from huggingface_hub import snapshot_download
import spaces
import tensorrt
import multiprocessing as mp
import pickle
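# Use the 'spawn' start method so worker processes start with a clean interpreter state
# (the default 'fork' does not play well with CUDA).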
mp.set_start_method('spawn', force=True)
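# Fetch the KDTalker checkpoints and assets from the Hugging Face Hub into a local folder.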
repo_id = "ChaolongYang/KDTalker"
local_dir = "./downloaded_repo"
snapshot_download(repo_id=repo_id, local_dir=local_dir)
print("\nFiles downloaded:")
for root, dirs, files in os.walk(local_dir):
    for file in files:
        file_path = os.path.join(root, file)
        print(file_path)
result_dir = "results"
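# Callbacks wired to the widgets below; their return value is written into the hidden
# driven_audio_type / reference_audio_type textboxes so predict() and generate() know
# which audio source is active.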
def set_upload():
    return "upload"

def set_microphone():
    return "microphone"

def set_tts():
    return "tts"

def create_kd_talker():
    return Inferencer()
example_folder = "example"
example_choices = ["Example 1", "Example 2", "Example 3"]
example_mapping = {
    "Example 1": {"audio": os.path.join(example_folder, "example1.wav"), "image": os.path.join(example_folder, "example1.png")},
    "Example 2": {"audio": os.path.join(example_folder, "example2.wav"), "image": os.path.join(example_folder, "example2.png")},
    "Example 3": {"audio": os.path.join(example_folder, "example3.wav"), "image": os.path.join(example_folder, "example3.png")},
}
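# Text-to-speech with voice cloning: synthesise the prompt in the voice of the chosen
# reference audio using Coqui's multilingual YourTTS model (runs on ZeroGPU via @spaces.GPU).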
@spaces.GPU
def predict(prompt, upload_reference_audio, microphone_reference_audio, reference_audio_type):
    global result_dir
    output_file_path = os.path.join('./downloaded_repo/', 'output.wav')
    # Pick the reference voice: an uploaded clip or a microphone recording.
    if reference_audio_type == 'upload':
        audio_file_pth = upload_reference_audio
    elif reference_audio_type == 'microphone':
        audio_file_pth = microphone_reference_audio
    tts = TTS('tts_models/multilingual/multi-dataset/your_tts')
    tts.tts_to_file(
        text=prompt,
        file_path=output_file_path,
        speaker_wav=audio_file_pth,
        language="en",
    )
    return gr.Audio(value=output_file_path, type='filepath')
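# Core KDTalker step: drive the source portrait with the selected audio and return the rendered video.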
@spaces.GPU
def generate(upload_driven_audio, tts_driven_audio, driven_audio_type, source_image, smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t):
    return Inferencer().generate_with_audio_img(upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
                                                smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t)
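# Build the Gradio UI: a left panel for the source image and driving audio (upload or TTS),
# and a right panel for the generated video, pose-smoothing sliders, and bundled examples.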
def main():
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    with gr.Blocks(analytics_enabled=False) as interface:
        with gr.Row():
            gr.HTML(
                """
                <div align='center'>
                    <h2> Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait </h2>
                    <div style="display: flex; justify-content: center; align-items: center; gap: 20px;">
                        <img src='https://newstatic.dukekunshan.edu.cn/mainsite/2021/08/07161629/large_dku-Logo-e1649298929570.png' alt='Logo' width='150'/>
                        <img src='https://www.xjtlu.edu.cn/wp-content/uploads/2023/12/7c52fd62e9cf26cb493faa7f91c2782.png' width='250'/>
                    </div>
                </div>
                """
            )
        driven_audio_type = gr.Textbox(value="upload", visible=False)
        reference_audio_type = gr.Textbox(value="upload", visible=False)
        with gr.Row():
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="kdtalker_source_image"):
                    with gr.TabItem("Upload image"):
                        source_image = gr.Image(label="Source image", sources="upload", type="filepath", scale=256)

                with gr.Tabs(elem_id="kdtalker_driven_audio"):
                    with gr.TabItem("Upload"):
                        upload_driven_audio = gr.Audio(label="Upload audio", sources="upload", type="filepath")
                        upload_driven_audio.change(set_upload, outputs=driven_audio_type)
                    with gr.TabItem("TTS"):
                        upload_reference_audio = gr.Audio(label="Upload Reference Audio", sources="upload", type="filepath")
                        upload_reference_audio.change(set_upload, outputs=reference_audio_type)
                        microphone_reference_audio = gr.Audio(label="Recorded Reference Audio", sources="microphone", type="filepath")
                        microphone_reference_audio.change(set_microphone, outputs=reference_audio_type)
                        input_text = gr.Textbox(
                            label="Generating audio from text",
                            lines=5,
                            placeholder="Please enter some text here; we generate the audio from the text using @Coqui.ai TTS."
                        )
                        tts_button = gr.Button("Generate audio", elem_id="kdtalker_audio_generate", variant="primary")
                        tts_driven_audio = gr.Audio(label="Synthesised Audio", type="filepath")
                        tts_button.click(fn=predict, inputs=[input_text, upload_reference_audio, microphone_reference_audio, reference_audio_type], outputs=[tts_driven_audio])
                        tts_button.click(set_tts, outputs=driven_audio_type)
            with gr.Column(variant="panel"):
                gen_video = gr.Video(label="Generated video", format="mp4", width=256)
                with gr.Tabs(elem_id="talker_checkbox"):
                    with gr.TabItem("KDTalker"):
                        smoothed_pitch = gr.Slider(minimum=0, maximum=1, step=0.1, label="Pitch", value=0.8)
                        smoothed_yaw = gr.Slider(minimum=0, maximum=1, step=0.1, label="Yaw", value=0.8)
                        smoothed_roll = gr.Slider(minimum=0, maximum=1, step=0.1, label="Roll", value=0.8)
                        smoothed_t = gr.Slider(minimum=0, maximum=1, step=0.1, label="T", value=0.8)
                        kd_submit = gr.Button("Generate", elem_id="kdtalker_generate", variant="primary")
                        kd_submit.click(
                            fn=generate,
                            inputs=[
                                upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
                                smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t
                            ],
                            outputs=[gen_video]
                        )
                    with gr.TabItem("Example"):
                        example_choice = gr.Dropdown(choices=example_choices, label="Choose an example")

                        def load_example(choice):
                            example = example_mapping.get(choice, {})
                            audio_path = example.get("audio", "")
                            image_path = example.get("image", "")
                            return [audio_path, image_path]

                        example_choice.change(
                            fn=load_example,
                            inputs=[example_choice],
                            outputs=[upload_driven_audio, source_image]
                        )
                        example_choice.change(set_upload, outputs=driven_audio_type)
    return interface
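# Build the interface and serve it with a public share link; queue() lets long-running
# GPU jobs wait in a request queue instead of timing out.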
demo = main()
demo.queue().launch(share=True)