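"""Gradio demo for audio-driven talking portraits: Coqui TTS clones a
reference voice to turn text into speech, and KDTalker
(difpoint.inference.Inferencer) animates a source image with the driving
audio. A SadTalker instance is also constructed from its bundled
checkpoints and config."""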
import os

import gradio as gr
import torch

from SadTalker.src.gradio_demo import SadTalker
from difpoint.inference import Inferencer
from TTS.api import TTS

# Device setup: prefer CUDA when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
tts = TTS('tts_models/multilingual/multi-dataset/your_tts').to(device)

result_dir = "results"

def predict(prompt, upload_reference_audio, microphone_reference_audio, reference_audio_type):
    """Clone the selected reference voice and synthesize `prompt` to results/output.wav."""
    os.makedirs(result_dir, exist_ok=True)
    output_file_path = os.path.join(result_dir, 'output.wav')
    # Use whichever reference audio the user supplied last.
    if reference_audio_type == 'microphone':
        audio_file_pth = microphone_reference_audio
    else:
        audio_file_pth = upload_reference_audio
    tts.tts_to_file(
        text=prompt,
        file_path=output_file_path,
        speaker_wav=audio_file_pth,
        language="en",
    )
    return gr.Audio(value=output_file_path, type='filepath')


def main(sadtalker_checkpoint_path=r"SadTalker/checkpoints", sadtalker_config_path=r"SadTalker/src/config"):
    print('device is', device)
    # Reuse the module-level `tts` (predict() references it); build the two talker backends.
    sad_talker = SadTalker(sadtalker_checkpoint_path, sadtalker_config_path, lazy_load=True)
    kd_talker = Inferencer()

    with gr.Blocks(analytics_enabled=False) as interface:
        gr.Markdown(
        """
            <div align='center'>
                <h2> Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait </h2>
                <div style="display: flex; justify-content: center; align-items: center; gap: 20px;">
                    <img src='https://newstatic.dukekunshan.edu.cn/mainsite/2021/08/07161629/large_dku-Logo-e1649298929570.png' alt='Logo' width='150'/>
                    <img src='https://www.xjtlu.edu.cn/wp-content/uploads/2023/12/7c52fd62e9cf26cb493faa7f91c2782.png' width='250'/>
                </div>
            </div>
        """
        )
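        # Tracks which audio drives the animation: "upload" (file) or "tts" (synthesized).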
        driven_audio_type = gr.State(value="upload")  

        with gr.Row():
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="sadtalker_source_image"):
                    with gr.TabItem("Upload image"):
                        source_image = gr.Image(label="Source image", sources="upload", type="filepath", height=256)
                
                with gr.Tabs(elem_id="sadtalker_driven_audio"):
                    with gr.TabItem("Upload"):
                        upload_driven_audio = gr.Audio(label="Upload audio", sources="upload", type="filepath")
                        upload_driven_audio.change(lambda: "upload", outputs=driven_audio_type)
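                    # Tracks which reference audio feeds the voice clone: "upload" or "microphone".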
                    reference_audio_type = gr.State(value="upload")

                    with gr.TabItem("TTS"):
                        upload_reference_audio = gr.Audio(label="Upload Reference Audio", sources="upload", type="filepath")
                        upload_reference_audio.change(lambda: "upload", outputs=reference_audio_type)
                        microphone_reference_audio = gr.Audio(label="Recorded Reference Audio", sources="microphone", type="filepath")
                        microphone_reference_audio.change(lambda: "microphone", outputs=reference_audio_type)
                        input_text = gr.Textbox(
                            label="Generating audio from text",
                            lines=5,
                            placeholder="Please enter some text here; the audio is generated from it using Coqui TTS."
                        )
                        tts_button = gr.Button("Generate audio", elem_id="sadtalker_audio_generate", variant="primary")
                        tts_driven_audio = gr.Audio(label="Synthesised Audio", type="filepath")
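                        # Both handlers fire on one click: synthesize the audio, then mark TTS as the driving source.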
                        tts_button.click(fn=predict, inputs=[input_text, upload_reference_audio, microphone_reference_audio, reference_audio_type], outputs=[tts_driven_audio])
                        tts_button.click(lambda: "tts", outputs=driven_audio_type)

            with gr.Column(variant="panel"):
                gen_video = gr.Video(label="Generated video", format="mp4", width=256)
                with gr.Tabs(elem_id="talker_checkbox"):
                    with gr.TabItem("KDTalker"):
                        smoothed_pitch = gr.Slider(minimum=0, maximum=1, step=0.1, label="Pitch", value=0.8)
                        smoothed_yaw = gr.Slider(minimum=0, maximum=1, step=0.1, label="Yaw", value=0.8)
                        smoothed_roll = gr.Slider(minimum=0, maximum=1, step=0.1, label="Roll", value=0.8)
                        smoothed_t = gr.Slider(minimum=0, maximum=1, step=0.1, label="T (translation)", value=0.8)
                        kd_submit = gr.Button("Generate", elem_id="kdtalker_generate", variant="primary")

                        kd_submit.click(
                            fn=kd_talker.generate_with_audio_img,
                            inputs=[
                                upload_driven_audio,
                                tts_driven_audio,
                                driven_audio_type,
                                source_image,
                                smoothed_pitch,
                                smoothed_yaw,
                                smoothed_roll,
                                smoothed_t,
                            ],
                            outputs=[gen_video],
                        )
    return interface

if __name__ == "__main__":
    os.environ["GRADIO_SERVER_PORT"] = "7860"
    demo = main()
    # HTTPS variant, if cert.pem/key.pem are provided:
    # demo.launch(server_name="0.0.0.0", ssl_certfile="cert.pem", ssl_keyfile="key.pem", ssl_verify=False, strict_cors=False)
    demo.launch(server_name="0.0.0.0")