# test_kdtalker / app.py
# Last commit: 62e754e "Update app.py" by YinuoGuo27 (verified), 5.8 kB
import os, sys
import gradio as gr
from difpoint.inference import Inferencer
from TTS.api import TTS
import torch
import time
from flask import send_from_directory
from huggingface_hub import snapshot_download
import spaces
import tensorrt
import multiprocessing as mp
import pickle
# Use the "spawn" start method for any multiprocessing workers; force=True
# overrides a start method that may already have been set.
mp.set_start_method('spawn', force=True)

# Hub repository with the KDTalker assets, and the local directory it is
# mirrored into at import time.
repo_id = "ChaolongYang/KDTalker"
local_dir = "./downloaded_repo"
snapshot_download(repo_id=repo_id, local_dir=local_dir)

# Log every downloaded file path so the mirrored contents are visible at startup.
print("\nFiles downloaded:")
for dirpath, _subdirs, filenames in os.walk(local_dir):
    for filename in filenames:
        print(os.path.join(dirpath, filename))

# Directory into which predict() writes its synthesized audio.
result_dir = "results"
def set_upload():
    """Return the mode tag ``"upload"``; wired to Gradio ``.change`` events."""
    mode = "upload"
    return mode
def set_microphone():
    """Return the mode tag ``"microphone"``; wired to a Gradio ``.change`` event."""
    mode = "microphone"
    return mode
def set_tts():
    """Return the mode tag ``"tts"``; wired to a Gradio ``.click`` event."""
    mode = "tts"
    return mode
def create_kd_talker():
    """Construct and return a new KDTalker ``Inferencer`` instance."""
    talker = Inferencer()
    return talker
@spaces.GPU
def predict(prompt, upload_reference_audio, microphone_reference_audio, reference_audio_type):
    """Synthesize ``prompt`` as speech in the voice of a reference recording.

    Args:
        prompt: Text to synthesize.
        upload_reference_audio: Filepath of an uploaded reference clip, or None.
        microphone_reference_audio: Filepath of a recorded reference clip, or None.
        reference_audio_type: ``"upload"`` or ``"microphone"``; selects which of
            the two reference clips is used as the speaker voice.

    Returns:
        A ``gr.Audio`` component pointing at ``<result_dir>/output.wav``.

    Raises:
        gr.Error: if ``reference_audio_type`` is not recognised, or the selected
            reference clip is missing.
    """
    # Choose the reference clip whose voice will be cloned.
    if reference_audio_type == 'upload':
        audio_file_pth = upload_reference_audio
    elif reference_audio_type == 'microphone':
        audio_file_pth = microphone_reference_audio
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise gr.Error(f"Unknown reference audio type: {reference_audio_type!r}")
    if audio_file_pth is None:
        # The your_tts model clones a speaker voice, so a reference clip is required.
        raise gr.Error("Please upload or record a reference audio clip first.")

    # Ensure the output directory exists before writing into it.
    os.makedirs(result_dir, exist_ok=True)
    output_file_path = os.path.join(result_dir, 'output.wav')

    torch.set_default_device('cuda')
    # NOTE(review): the TTS model is reloaded on every call; consider caching
    # it if that is compatible with the @spaces.GPU execution model.
    tts = TTS('tts_models/multilingual/multi-dataset/your_tts').to('cuda')
    tts.tts_to_file(
        text=prompt,
        file_path=output_file_path,
        speaker_wav=audio_file_pth,
        language="en",
    )
    return gr.Audio(value=output_file_path, type='filepath')
@spaces.GPU
def main():
    """Build and return (without launching) the KDTalker Gradio interface.

    The left panel takes a source image and the driving audio. That audio is
    either an uploaded clip or speech synthesized by ``predict`` from text plus
    a reference voice. The right panel holds the pose-smoothing sliders, the
    Generate button and the output video.

    NOTE(review): building the UI under @spaces.GPU is unusual; presumably it
    is so that the Inferencer constructed here can use the GPU. Confirm.
    """
    with gr.Blocks(analytics_enabled=False) as interface:
        gr.Markdown(
            """
<div align='center'>
<h2> Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait </h2>
<div style="display: flex; justify-content: center; align-items: center; gap: 20px;">
<img src='https://newstatic.dukekunshan.edu.cn/mainsite/2021/08/07161629/large_dku-Logo-e1649298929570.png' alt='Logo' width='150'/>
<img src='https://www.xjtlu.edu.cn/wp-content/uploads/2023/12/7c52fd62e9cf26cb493faa7f91c2782.png' width='250'/>
</div>
</div>
"""
        )
        # Hidden state: which driving audio the generator should use
        # ("upload" or "tts"). It is set by the upload/TTS events below.
        driven_audio_type = gr.Textbox(value="upload", visible=False)
        with gr.Row():
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="kdtalker_source_image"):
                    with gr.TabItem("Upload image"):
                        source_image = gr.Image(label="Source image", sources="upload", type="filepath", scale=256)
                with gr.Tabs(elem_id="kdtalker_driven_audio"):
                    with gr.TabItem("Upload"):
                        upload_driven_audio = gr.Audio(label="Upload audio", sources="upload", type="filepath")
                        upload_driven_audio.change(set_upload, outputs=driven_audio_type)
                        # Hidden state: which reference clip predict() should clone
                        # ("upload" or "microphone").
                        reference_audio_type = gr.Textbox(value="upload", visible=False)
                    with gr.TabItem("TTS"):
                        upload_reference_audio = gr.Audio(label="Upload Reference Audio", sources="upload", type="filepath")
                        upload_reference_audio.change(set_upload, outputs=reference_audio_type)
                        microphone_reference_audio = gr.Audio(label="Recorded Reference Audio", sources="microphone", type="filepath")
                        microphone_reference_audio.change(set_microphone, outputs=reference_audio_type)
                        input_text = gr.Textbox(
                            label="Generating audio from text",
                            lines=5,
                            placeholder="please enter some text here, we generate the audio from text using @Coqui.ai TTS."
                        )
                        tts_button = gr.Button("Generate audio", elem_id="kdtalker_audio_generate", variant="primary")
                        tts_driven_audio = gr.Audio(label="Synthesised Audio", type="filepath")
                        tts_button.click(fn=predict, inputs=[input_text, upload_reference_audio, microphone_reference_audio, reference_audio_type], outputs=[tts_driven_audio])
                        # A second listener on the same button switches the
                        # driving-audio mode to the synthesized clip.
                        tts_button.click(set_tts, outputs=driven_audio_type)
            with gr.Column(variant="panel"):
                gen_video = gr.Video(label="Generated video", format="mp4", width=256)
                with gr.Tabs(elem_id="talker_checkbox"):
                    with gr.TabItem("KDTalker"):
                        smoothed_pitch = gr.Slider(minimum=0, maximum=1, step=0.1, label="Pitch", value=0.8)
                        smoothed_yaw = gr.Slider(minimum=0, maximum=1, step=0.1, label="Yaw", value=0.8)
                        smoothed_roll = gr.Slider(minimum=0, maximum=1, step=0.1, label="Roll", value=0.8)
                        smoothed_t = gr.Slider(minimum=0, maximum=1, step=0.1, label="T", value=0.8)
                        kd_submit = gr.Button("Generate", elem_id="kdtalker_generate", variant="primary")
                        # Use the shared factory instead of constructing Inferencer inline.
                        kd_submit.click(
                            fn=create_kd_talker().generate_with_audio_img,
                            inputs=[
                                upload_driven_audio, tts_driven_audio, driven_audio_type, source_image,
                                smoothed_pitch, smoothed_yaw, smoothed_roll, smoothed_t
                            ],
                            outputs=[gen_video]
                        )
    return interface
if __name__ == "__main__":
    # Keep the env var for anything that reads it later. Gradio may read
    # GRADIO_SERVER_PORT when it is imported, which has already happened here,
    # so the port is also passed explicitly to launch().
    os.environ["GRADIO_SERVER_PORT"] = "7860"
    demo = main()
    # Bind on all interfaces so the server is reachable from other hosts.
    demo.launch(server_name="0.0.0.0", server_port=7860)