"""Gradio demo for KDTalker: audio-driven talking portrait generation."""
import multiprocessing as mp
import os
import pickle
import sys
import time

import gradio as gr
import spaces
import tensorrt
import torch
from flask import send_from_directory
from huggingface_hub import snapshot_download
from TTS.api import TTS

from difpoint.inference import Inferencer

# CUDA-based inference requires the "spawn" start method for subprocesses.
mp.set_start_method('spawn', force=True)

# Download the pretrained KDTalker assets from the Hugging Face Hub.
repo_id = "ChaolongYang/KDTalker"
local_dir = "./downloaded_repo"
snapshot_download(repo_id=repo_id, local_dir=local_dir)

print("\nFiles downloaded:")
for root, dirs, files in os.walk(local_dir):
    for file in files:
        file_path = os.path.join(root, file)
        print(file_path)

result_dir = "results"


def set_upload():
    return "upload"


def set_microphone():
    return "microphone"


def set_tts():
    return "tts"


def create_kd_talker():
    return Inferencer()


# Bundled examples shown in the "Example" tab.
example_folder = "example"
example_choices = ["Example 1", "Example 2", "Example 3"]
example_mapping = {
    "Example 1": {"audio": os.path.join(example_folder, "example1.wav"),
                  "image": os.path.join(example_folder, "example1.png")},
    "Example 2": {"audio": os.path.join(example_folder, "example2.wav"),
                  "image": os.path.join(example_folder, "example2.png")},
    "Example 3": {"audio": os.path.join(example_folder, "example3.wav"),
                  "image": os.path.join(example_folder, "example3.png")},
}


@spaces.GPU
def predict(prompt, upload_reference_audio, microphone_reference_audio, reference_audio_type):
    """Synthesise speech from text with Coqui YourTTS, cloning the chosen reference voice."""
    global result_dir
    output_file_path = os.path.join('./downloaded_repo/', 'output.wav')
    if reference_audio_type == 'upload':
        audio_file_pth = upload_reference_audio
    elif reference_audio_type == 'microphone':
        audio_file_pth = microphone_reference_audio
    else:
        # Fall back to the uploaded reference so audio_file_pth is always bound.
        audio_file_pth = upload_reference_audio
    tts = TTS('tts_models/multilingual/multi-dataset/your_tts')
    tts.tts_to_file(
        text=prompt,
        file_path=output_file_path,
        speaker_wav=audio_file_pth,
        language="en",
    )
    return gr.Audio(value=output_file_path, type='filepath')


@spaces.GPU
def generate(upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
             smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t):
    """Animate the source image with the selected driving audio using KDTalker."""
    return Inferencer().generate_with_audio_img(
        upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
        smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t,
    )


def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    with gr.Blocks(analytics_enabled=False) as interface:
        with gr.Row():
            gr.HTML(
                """

                <div>
                    <h2>Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait</h2>
                    <p>Logo</p>
                </div>
                """
            )

        # Hidden state tracking where the driving / reference audio comes from.
        driven_audio_type = gr.Textbox(value="upload", visible=False)
        reference_audio_type = gr.Textbox(value="upload", visible=False)

        with gr.Row():
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="kdtalker_source_image"):
                    with gr.TabItem("Upload image"):
                        source_image = gr.Image(label="Source image", sources="upload", type="filepath", scale=256)

                with gr.Tabs(elem_id="kdtalker_driven_audio"):
                    with gr.TabItem("Upload"):
                        upload_driven_audio = gr.Audio(label="Upload audio", sources="upload", type="filepath")
                        upload_driven_audio.change(set_upload, outputs=driven_audio_type)
                    with gr.TabItem("TTS"):
                        upload_reference_audio = gr.Audio(label="Upload Reference Audio", sources="upload", type="filepath")
                        upload_reference_audio.change(set_upload, outputs=reference_audio_type)
                        microphone_reference_audio = gr.Audio(label="Recorded Reference Audio", sources="microphone", type="filepath")
                        microphone_reference_audio.change(set_microphone, outputs=reference_audio_type)
                        input_text = gr.Textbox(
                            label="Generating audio from text",
                            lines=5,
                            placeholder="Please enter some text here; the audio is generated from it with @Coqui.ai TTS.",
                        )
                        tts_button = gr.Button("Generate audio", elem_id="kdtalker_audio_generate", variant="primary")
                        tts_driven_audio = gr.Audio(label="Synthesised Audio", type="filepath")
                        tts_button.click(
                            fn=predict,
                            inputs=[input_text, upload_reference_audio, microphone_reference_audio, reference_audio_type],
                            outputs=[tts_driven_audio],
                        )
                        tts_button.click(set_tts, outputs=driven_audio_type)

            with gr.Column(variant="panel"):
                gen_video = gr.Video(label="Generated video", format="mp4", width=256)
                with gr.Tabs(elem_id="talker_checkbox"):
                    with gr.TabItem("KDTalker"):
                        # Smoothing factors for the predicted head pose (pitch/yaw/roll) and translation (T).
                        smoothed_pitch = gr.Slider(minimum=0, maximum=1, step=0.1, label="Pitch", value=0.8)
                        smoothed_yaw = gr.Slider(minimum=0, maximum=1, step=0.1, label="Yaw", value=0.8)
                        smoothed_roll = gr.Slider(minimum=0, maximum=1, step=0.1, label="Roll", value=0.8)
                        smoothed_t = gr.Slider(minimum=0, maximum=1, step=0.1, label="T", value=0.8)
                        kd_submit = gr.Button("Generate", elem_id="kdtalker_generate", variant="primary")
                        kd_submit.click(
                            fn=generate,
                            inputs=[
                                upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
                                smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t,
                            ],
                            outputs=[gen_video],
                        )
                    with gr.TabItem("Example"):
                        example_choice = gr.Dropdown(choices=example_choices, label="Choose an example")

                        def load_example(choice):
                            example = example_mapping.get(choice, {})
                            audio_path = example.get("audio", "")
                            image_path = example.get("image", "")
                            return [audio_path, image_path]

                        example_choice.change(
                            fn=load_example,
                            inputs=[example_choice],
                            outputs=[upload_driven_audio, source_image],
                        )
                        example_choice.change(set_upload, outputs=driven_audio_type)

    return interface


demo = main()
demo.queue().launch(share=True)