Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,343 Bytes
b0e2648 da524b2 b0e2648 6e48b01 6cee1c8 e7775ac 40152ba be5d017 5b1bdc7 11c707b 4ef033b e2f7b24 cd1584f b0e2648 2a2d679 c40d13d b0e2648 b005ca9 e2732df eb2b3c2 b0e2648 d9bc34a b0e2648 fb797bd b0e2648 e2732df 99167d3 3ac881c 7beb039 b0e2648 e88cdba 624cf9d e88cdba b0e2648 e88cdba ec8499f b0e2648 7beb039 b0e2648 7beb039 b0e2648 2a2d679 b0e2648 2a2d679 b0e2648 2a2d679 b0e2648 7beb039 b0e2648 eb2b3c2 2a2d679 b0e2648 e2732df c40d13d b0e2648 c40d13d b005ca9 bc465fc cdb1e49 ec8499f b005ca9 bc465fc b0e2648 8b4a444 1595e30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os, sys
import gradio as gr
from difpoint.inference import Inferencer
from TTS.api import TTS
import torch
import time
from flask import send_from_directory
from huggingface_hub import snapshot_download
import spaces
import tensorrt
import multiprocessing as mp
import pickle
# Use the "spawn" start method for multiprocessing children; force=True
# replaces a start method that may already have been set.
mp.set_start_method('spawn', force=True)

# Download a snapshot of the KDTalker repository from the Hugging Face Hub
# into local_dir. This runs at import time, before the UI is built.
repo_id = "ChaolongYang/KDTalker"
local_dir = "./downloaded_repo"
snapshot_download(repo_id=repo_id, local_dir=local_dir)

# Log every file found under local_dir so the download can be checked
# in the application output.
print("\nFiles downloaded:")
for root, dirs, files in os.walk(local_dir):
    for file in files:
        file_path = os.path.join(root, file)
        print(file_path)

# Name of the results directory.
result_dir = "results"
def set_upload() -> str:
    """Return the audio-source tag ``"upload"``.

    Used as a Gradio event handler whose return value is written into one of
    the hidden audio-type textboxes.
    """
    return "upload"
def set_microphone() -> str:
    """Return the audio-source tag ``"microphone"``.

    Used as a Gradio event handler whose return value is written into the
    hidden reference-audio-type textbox.
    """
    return "microphone"
def set_tts() -> str:
    """Return the audio-source tag ``"tts"``.

    Used as a Gradio event handler whose return value is written into the
    hidden driven-audio-type textbox.
    """
    return "tts"
def create_kd_talker():
    """Create and return a new ``Inferencer`` instance."""
    return Inferencer()
# Bundled demo inputs: each example label maps to an audio clip and a portrait
# image stored in example_folder as example<N>.wav / example<N>.png.
example_folder = "example"
example_choices = [f"Example {number}" for number in (1, 2, 3)]
example_mapping = {
    f"Example {number}": {
        "audio": os.path.join(example_folder, f"example{number}.wav"),
        "image": os.path.join(example_folder, f"example{number}.png"),
    }
    for number in (1, 2, 3)
}
@spaces.GPU
def predict(prompt, upload_reference_audio, microphone_reference_audio, reference_audio_type):
    """Synthesise speech for ``prompt`` using a reference recording as the speaker.

    Args:
        prompt: Text to be spoken.
        upload_reference_audio: File path of the uploaded reference audio;
            used when ``reference_audio_type`` is ``'upload'``.
        microphone_reference_audio: File path of the recorded reference audio;
            used when ``reference_audio_type`` is ``'microphone'``.
        reference_audio_type: ``'upload'`` or ``'microphone'``; selects which
            of the two reference recordings is passed to the TTS model.

    Returns:
        A ``gr.Audio`` component pointing at the generated
        ``./downloaded_repo/output.wav`` file.

    Raises:
        gr.Error: If ``reference_audio_type`` is not one of the two known
            values, or the selected reference recording is missing.
    """
    output_file_path = os.path.join('./downloaded_repo/', 'output.wav')

    if reference_audio_type == 'upload':
        audio_file_pth = upload_reference_audio
    elif reference_audio_type == 'microphone':
        audio_file_pth = microphone_reference_audio
    else:
        # Without this branch the variable would be unbound and the call
        # would fail later with an UnboundLocalError.
        raise gr.Error(f"Unknown reference audio type: {reference_audio_type!r}")

    if not audio_file_pth:
        # Fail early with a readable message instead of handing None to the
        # TTS model as the speaker sample.
        raise gr.Error("Please upload or record a reference audio first.")

    # The model is instantiated on every call, inside the GPU-decorated function.
    tts = TTS('tts_models/multilingual/multi-dataset/your_tts')
    tts.tts_to_file(
        text=prompt,
        file_path=output_file_path,
        speaker_wav=audio_file_pth,
        language="en",
    )
    return gr.Audio(value=output_file_path, type='filepath')
@spaces.GPU
def generate(upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
             smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t):
    """Run the talking-portrait inference for the selected audio and image.

    A new ``Inferencer`` is created on every call and all arguments are
    forwarded unchanged, in order, to ``generate_with_audio_img``.

    Args:
        upload_driven_audio: File path of the uploaded driving audio.
        tts_driven_audio: File path of the synthesised driving audio.
        driven_audio_type: Audio-source tag held in the hidden textbox
            (``"upload"`` or ``"tts"``).
        source_image: File path of the source portrait image.
        smoothed_pitch: Value of the "Pitch" slider (0 to 1).
        smoothed_yaw: Value of the "Yaw" slider (0 to 1).
        smoothed_roll: Value of the "Roll" slider (0 to 1).
        smoothed_t: Value of the "T" slider (0 to 1).

    Returns:
        Whatever ``Inferencer.generate_with_audio_img`` returns; the UI feeds
        it to the "Generated video" component.
    """
    return Inferencer().generate_with_audio_img(
        upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
        smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t,
    )
def main():
    """Build the KDTalker Gradio interface.

    The left panel holds the source image and the driving audio (either an
    uploaded file or speech synthesised from text with a reference voice).
    The right panel holds the generated video, the smoothing sliders with the
    "Generate" button, and a tab that loads one of the bundled examples.

    NOTE(review): the nesting of the layout blocks was reconstructed from a
    copy of this file that had lost its indentation; confirm it against the
    rendered UI.

    Returns:
        The assembled ``gr.Blocks`` interface. It is not launched here.
    """
    with gr.Blocks(analytics_enabled=False) as interface:
        with gr.Row():
            # The HTML literal is kept flush-left so its content is unchanged.
            gr.HTML(
                """
<div align='center'>
<h2> Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait </h2>
<div style="display: flex; justify-content: center; align-items: center; gap: 20px;">
<img src='https://newstatic.dukekunshan.edu.cn/mainsite/2021/08/07161629/large_dku-Logo-e1649298929570.png' alt='Logo' width='150'/>
<img src='https://www.xjtlu.edu.cn/wp-content/uploads/2023/12/7c52fd62e9cf26cb493faa7f91c2782.png' width='250'/>
</div>
</div>
"""
            )

        # Hidden textboxes that remember which audio source was used last:
        # the set_* handlers write "upload" / "microphone" / "tts" into them.
        driven_audio_type = gr.Textbox(value="upload", visible=False)
        reference_audio_type = gr.Textbox(value="upload", visible=False)

        with gr.Row():
            # Left panel: source image and driving audio.
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="kdtalker_source_image"):
                    with gr.TabItem("Upload image"):
                        source_image = gr.Image(label="Source image", sources="upload", type="filepath", scale=256)

                with gr.Tabs(elem_id="kdtalker_driven_audio"):
                    with gr.TabItem("Upload"):
                        upload_driven_audio = gr.Audio(label="Upload audio", sources="upload", type="filepath")
                        upload_driven_audio.change(set_upload, outputs=driven_audio_type)
                    with gr.TabItem("TTS"):
                        upload_reference_audio = gr.Audio(label="Upload Reference Audio", sources="upload", type="filepath")
                        upload_reference_audio.change(set_upload, outputs=reference_audio_type)
                        microphone_reference_audio = gr.Audio(label="Recorded Reference Audio", sources="microphone", type="filepath")
                        microphone_reference_audio.change(set_microphone, outputs=reference_audio_type)
                        input_text = gr.Textbox(
                            label="Generating audio from text",
                            lines=5,
                            placeholder="please enter some text here, we generate the audio from text using @Coqui.ai TTS."
                        )
                        tts_button = gr.Button("Generate audio", elem_id="kdtalker_audio_generate", variant="primary")
                        tts_driven_audio = gr.Audio(label="Synthesised Audio", type="filepath")
                        # Two listeners on the same click: one synthesises the
                        # audio, the other marks TTS as the driving source.
                        tts_button.click(fn=predict, inputs=[input_text, upload_reference_audio, microphone_reference_audio, reference_audio_type], outputs=[tts_driven_audio])
                        tts_button.click(set_tts, outputs=driven_audio_type)

            # Right panel: output video, generation controls and examples.
            with gr.Column(variant="panel"):
                gen_video = gr.Video(label="Generated video", format="mp4", width=256)
                with gr.Tabs(elem_id="talker_checkbox"):
                    with gr.TabItem("KDTalker"):
                        smoothed_pitch = gr.Slider(minimum=0, maximum=1, step=0.1, label="Pitch", value=0.8)
                        smoothed_yaw = gr.Slider(minimum=0, maximum=1, step=0.1, label="Yaw", value=0.8)
                        smoothed_roll = gr.Slider(minimum=0, maximum=1, step=0.1, label="Roll", value=0.8)
                        smoothed_t = gr.Slider(minimum=0, maximum=1, step=0.1, label="T", value=0.8)
                        kd_submit = gr.Button("Generate", elem_id="kdtalker_generate", variant="primary")
                        kd_submit.click(
                            fn=generate,
                            inputs=[
                                upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
                                smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t
                            ],
                            outputs=[gen_video]
                        )
                    with gr.TabItem("Example"):
                        example_choice = gr.Dropdown(choices=example_choices, label="Choose an example")

                        def load_example(choice):
                            """Return ``[audio_path, image_path]`` for the chosen example.

                            Both entries are ``None`` when ``choice`` is not a
                            known example, so the components receive no value
                            rather than an empty string as a file path.
                            """
                            example = example_mapping.get(choice, {})
                            return [example.get("audio"), example.get("image")]

                        example_choice.change(
                            fn=load_example,
                            inputs=[example_choice],
                            outputs=[upload_driven_audio, source_image]
                        )
                        # Loading an example fills the upload widget, so the
                        # driving source is reset to "upload".
                        example_choice.change(set_upload, outputs=driven_audio_type)
    return interface
# Build the interface, enable Gradio's request queue, and start the server
# with a public share link requested.
demo = main()
demo.queue().launch(share=True)