File size: 7,343 Bytes
b0e2648
 
da524b2
b0e2648
 
 
 
6e48b01
6cee1c8
e7775ac
40152ba
be5d017
5b1bdc7
11c707b
4ef033b
e2f7b24
 
 
cd1584f
 
 
 
 
b0e2648
 
2a2d679
 
 
 
 
 
c40d13d
 
b0e2648
b005ca9
 
 
 
 
 
 
 
e2732df
eb2b3c2
b0e2648
d9bc34a
b0e2648
 
 
 
fb797bd
b0e2648
 
 
 
 
 
 
 
e2732df
 
 
 
99167d3
3ac881c
7beb039
b0e2648
 
 
 
 
e88cdba
624cf9d
e88cdba
 
 
 
 
 
 
b0e2648
e88cdba
 
ec8499f
 
b0e2648
 
7beb039
b0e2648
 
7beb039
b0e2648
 
2a2d679
b0e2648
 
2a2d679
b0e2648
2a2d679
b0e2648
 
 
 
 
7beb039
b0e2648
eb2b3c2
2a2d679
b0e2648
 
 
 
 
 
 
 
 
 
e2732df
c40d13d
 
 
b0e2648
c40d13d
 
b005ca9
 
bc465fc
 
 
 
cdb1e49
ec8499f
 
 
 
b005ca9
 
 
bc465fc
b0e2648
 
8b4a444
 
1595e30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os, sys
import gradio as gr
from difpoint.inference import Inferencer
from TTS.api import TTS
import torch
import time
from flask import send_from_directory
from huggingface_hub import snapshot_download
import spaces
import tensorrt
import multiprocessing as mp
import pickle
# Force the 'spawn' start method for any multiprocessing children.
# NOTE(review): presumably because CUDA cannot be re-initialised in forked
# processes — confirm against the inference code that spawns workers.
mp.set_start_method('spawn', force=True)

# Fetch the KDTalker repository snapshot from the Hugging Face Hub into a
# local folder (runs at import time).
repo_id = "ChaolongYang/KDTalker"
local_dir = "./downloaded_repo"
snapshot_download(repo_id=repo_id, local_dir=local_dir)

# Log every downloaded file path so the run logs show what was fetched.
print("\nFiles downloaded:")
for dirpath, _subdirs, filenames in os.walk(local_dir):
    for name in filenames:
        print(os.path.join(dirpath, name))

# Presumably the directory name for generated results.
# NOTE(review): no code visible in this file reads this value; confirm it is
# unused elsewhere before removing it.
result_dir = "results"
def set_upload():
    """Return the mode tag ``"upload"``.

    Used as a Gradio event handler: its return value is written into a hidden
    mode textbox (driving-audio or reference-audio type) to mark uploaded
    audio as the active source.
    """
    mode = "upload"
    return mode
def set_microphone():
    """Return the mode tag ``"microphone"``.

    Used as a Gradio event handler: its return value is written into the
    hidden reference-audio-type textbox to select the recorded clip.
    """
    mode = "microphone"
    return mode
def set_tts():
    """Return the mode tag ``"tts"``.

    Used as a Gradio event handler: its return value is written into the
    hidden driving-audio-type textbox to select the synthesised audio.
    """
    mode = "tts"
    return mode
def create_kd_talker():
    """Construct and return a new KDTalker ``Inferencer`` instance."""
    talker = Inferencer()
    return talker

# Bundled demo inputs: each dropdown label maps to a driving-audio clip and a
# source portrait stored under ``example/`` as exampleN.wav / exampleN.png.
example_folder = "example"
example_choices = ["Example 1", "Example 2", "Example 3"]
example_mapping = {
    label: {
        "audio": os.path.join(example_folder, f"example{index}.wav"),
        "image": os.path.join(example_folder, f"example{index}.png"),
    }
    for index, label in enumerate(example_choices, start=1)
}

@spaces.GPU
def predict(prompt, upload_reference_audio, microphone_reference_audio, reference_audio_type):
    """Synthesise *prompt* as speech in the voice of a reference recording.

    Uses Coqui TTS's ``tts_models/multilingual/multi-dataset/your_tts`` model
    with ``language="en"`` and writes the result to ``<local_dir>/output.wav``.

    Args:
        prompt: Text to speak.
        upload_reference_audio: File path of the uploaded reference clip, or None.
        microphone_reference_audio: File path of the recorded reference clip, or None.
        reference_audio_type: Which reference to use, ``"upload"`` or
            ``"microphone"`` (kept in a hidden textbox by the UI).

    Returns:
        A ``gr.Audio`` value pointing at the synthesised wav file.

    Raises:
        gr.Error: If the prompt is empty, the reference type is unknown, or the
            selected reference clip has not been provided.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter some text to synthesise.")

    reference_sources = {
        'upload': upload_reference_audio,
        'microphone': microphone_reference_audio,
    }
    # An unknown type used to leave audio_file_pth unbound (UnboundLocalError).
    if reference_audio_type not in reference_sources:
        raise gr.Error(f"Unknown reference audio type: {reference_audio_type!r}")
    audio_file_pth = reference_sources[reference_audio_type]
    if not audio_file_pth:
        raise gr.Error("Please upload or record a reference audio clip first.")

    # NOTE(review): fixed output path — concurrent requests would overwrite
    # each other's file; fine while the queue processes one request at a time.
    output_file_path = os.path.join(local_dir, 'output.wav')
    tts = TTS('tts_models/multilingual/multi-dataset/your_tts')
    tts.tts_to_file(
        text=prompt,
        file_path=output_file_path,
        speaker_wav=audio_file_pth,
        language="en",
    )
    return gr.Audio(value=output_file_path, type='filepath')

@spaces.GPU
def generate(upload_driven_audio, tts_driven_audio, driven_audio_type, source_image, smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t):
    """Render a talking-portrait video from a source image and driving audio.

    Args:
        upload_driven_audio: File path of the uploaded driving audio, or None.
        tts_driven_audio: File path of the TTS-synthesised driving audio, or None.
        driven_audio_type: ``"upload"`` or ``"tts"``, i.e. which of the two
            audio inputs to use.
        source_image: File path of the portrait to animate.
        smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t: Values from
            the 0-1 "Pitch"/"Yaw"/"Roll"/"T" sliders (presumably pose
            smoothing strengths — confirm in ``Inferencer``).

    Returns:
        Whatever ``Inferencer.generate_with_audio_img`` returns. It feeds the
        ``gr.Video`` output, so presumably a video file path — confirm.
    """
    # A fresh Inferencer is built on every call, as in the original code.
    talker = Inferencer()
    pose_smoothing = (smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t)
    return talker.generate_with_audio_img(
        upload_driven_audio,
        tts_driven_audio,
        driven_audio_type,
        source_image,
        *pose_smoothing,
    )


def main():
    if torch.cuda.is_available():
        device = "cuda" 
    else:
        device = "cpu"
    with gr.Blocks(analytics_enabled=False) as interface:
        with gr.Row():
            gr.HTML(
            """
                <div align='center'>
                    <h2> Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait </h2>
                    <div style="display: flex; justify-content: center; align-items: center; gap: 20px;">
                        <img src='https://newstatic.dukekunshan.edu.cn/mainsite/2021/08/07161629/large_dku-Logo-e1649298929570.png' alt='Logo' width='150'/>
                        <img src='https://www.xjtlu.edu.cn/wp-content/uploads/2023/12/7c52fd62e9cf26cb493faa7f91c2782.png' width='250'/>
                    </div>
                </div>
            """
            )
            driven_audio_type = gr.Textbox(value="upload", visible=False) 
            reference_audio_type = gr.Textbox(value="upload", visible=False)
        with gr.Row():
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="kdtalker_source_image"):
                    with gr.TabItem("Upload image"):
                        source_image = gr.Image(label="Source image", sources="upload", type="filepath", scale=256)
                with gr.Tabs(elem_id="kdtalker_driven_audio"):
                    with gr.TabItem("Upload"):
                        upload_driven_audio = gr.Audio(label="Upload audio", sources="upload", type="filepath")
                        upload_driven_audio.change(set_upload, outputs=driven_audio_type)
                    with gr.TabItem("TTS"):
                        upload_reference_audio = gr.Audio(label="Upload Reference Audio", sources="upload", type="filepath")
                        upload_reference_audio.change(set_upload, outputs=reference_audio_type)
                        microphone_reference_audio = gr.Audio(label="Recorded Reference Audio", sources="microphone", type="filepath")
                        microphone_reference_audio.change(set_microphone, outputs=reference_audio_type)
                        input_text = gr.Textbox(
                            label="Generating audio from text",
                            lines=5,
                            placeholder="please enter some text here, we generate the audio from text using @Coqui.ai TTS."
                        )
                        tts_button = gr.Button("Generate audio", elem_id="kdtalker_audio_generate", variant="primary")
                        tts_driven_audio = gr.Audio(label="Synthesised Audio", type="filepath")
                        tts_button.click(fn=predict, inputs=[input_text, upload_reference_audio, microphone_reference_audio, reference_audio_type], outputs=[tts_driven_audio])
                        tts_button.click(set_tts, outputs=driven_audio_type)
            with gr.Column(variant="panel"):
                gen_video = gr.Video(label="Generated video", format="mp4", width=256)
                with gr.Tabs(elem_id="talker_checkbox"):
                    with gr.TabItem("KDTalker"):
                        smoothed_pitch = gr.Slider(minimum=0, maximum=1, step=0.1, label="Pitch", value=0.8)
                        smoothed_yaw = gr.Slider(minimum=0, maximum=1, step=0.1, label="Yaw", value=0.8)
                        smoothed_roll = gr.Slider(minimum=0, maximum=1, step=0.1, label="Roll", value=0.8)
                        smoothed_t = gr.Slider(minimum=0, maximum=1, step=0.1, label="T", value=0.8)
                        kd_submit = gr.Button("Generate", elem_id="kdtalker_generate", variant="primary")
                        kd_submit.click(
                                fn=generate,
                                inputs=[
                                    upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
                                    smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t
                                ],
                                outputs=[gen_video]
                            )
                    with gr.TabItem("Example"):
                        example_choice = gr.Dropdown(choices=example_choices, label="Choose an example")
                        def load_example(choice):
                            example = example_mapping.get(choice, {})
                            audio_path = example.get("audio", "")
                            image_path = example.get("image", "")
                            return [audio_path, image_path]
                        example_choice.change(
                            fn=load_example, 
                            inputs=[example_choice], 
                            outputs=[upload_driven_audio, source_image]
                        )
                        example_choice.change(set_upload, outputs=driven_audio_type)


    return interface


# Build the interface at module level so importers can still reach ``demo``,
# but only start the (queued, shared-link) server when run as a script.
demo = main()
if __name__ == "__main__":
    demo.queue().launch(share=True)