# test_kdtalker / app.py
# YinuoGuo27's picture
# Update app.py
# da524b2 verified
# raw
# history blame
# 6.21 kB
import os, sys
import gradio as gr
from difpoint.inference import Inferencer
from TTS.api import TTS
import torch
import time
from flask import send_from_directory
from huggingface_hub import snapshot_download
# --- Model assets -----------------------------------------------------------
# Download a snapshot of the KDTalker repository from the Hugging Face Hub
# into a local directory next to this script.
repo_id = "ChaolongYang/KDTalker"
local_dir = "./downloaded_repo"
snapshot_download(repo_id=repo_id, local_dir=local_dir)
print("Download complete! Repository saved at:", local_dir)

# List every downloaded file so the startup log shows what is available.
print("\nFiles downloaded:")
for root, dirs, files in os.walk(local_dir):
    for file in files:
        file_path = os.path.join(root, file)
        print(file_path)

# --- Device and TTS model ---------------------------------------------------
# Prefer the GPU when one is visible and make it torch's default device.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)

# Module-level TTS instance used by predict(). The `.to(device)` result is
# assigned here, so no further `.to(device)` call is needed afterwards.
tts = TTS('tts_models/multilingual/multi-dataset/your_tts').to(device)

# --- Output location --------------------------------------------------------
# predict() writes its synthesised wav file here; create the directory up
# front so the first request does not fail on a missing path.
result_dir = "results"
os.makedirs(result_dir, exist_ok=True)
def predict(prompt, upload_reference_audio, microphone_reference_audio, reference_audio_type):
    """Synthesise speech for ``prompt`` using a reference recording as the voice.

    Args:
        prompt: Text to synthesise.
        upload_reference_audio: Path of the uploaded reference audio (may be
            empty when nothing was uploaded).
        microphone_reference_audio: Path of the recorded reference audio (may
            be empty when nothing was recorded).
        reference_audio_type: ``'upload'`` or ``'microphone'``; selects which
            of the two reference paths is passed to the TTS model.

    Returns:
        A ``gr.Audio`` pointing at the generated wav file inside ``result_dir``.

    Raises:
        ValueError: If ``reference_audio_type`` is not a known source, or the
            selected reference audio is missing.
    """
    reference_by_type = {
        'upload': upload_reference_audio,
        'microphone': microphone_reference_audio,
    }
    # Validate first: an unknown type used to surface as an UnboundLocalError
    # further down instead of a readable message.
    if reference_audio_type not in reference_by_type:
        raise ValueError(
            f"Unknown reference audio type: {reference_audio_type!r} "
            "(expected 'upload' or 'microphone')"
        )

    audio_file_pth = reference_by_type[reference_audio_type]
    if not audio_file_pth:
        raise ValueError(
            f"No reference audio provided for source {reference_audio_type!r}"
        )

    # Make sure the output directory exists before the TTS engine writes to it.
    os.makedirs(result_dir, exist_ok=True)
    # NOTE(review): the output name is fixed, so simultaneous requests
    # overwrite each other's file — confirm whether that matters here.
    output_file_path = os.path.join(result_dir, 'output.wav')

    tts.tts_to_file(
        text=prompt,
        file_path=output_file_path,
        speaker_wav=audio_file_pth,
        language="en",
    )
    return gr.Audio(value=output_file_path, type='filepath')
def main(sadtaker_checkpoint_path=r"SadTalker/checkpoints", sadtalker_config_path=r"SadTalker/src/config"):
    """Build the KDTalker Gradio interface and return it without launching.

    The UI has two columns: the left one collects the source image and the
    driving audio (either uploaded directly or synthesised from text with a
    reference voice), the right one shows the generated video together with
    the smoothing sliders and the "Generate" button.

    Args:
        sadtaker_checkpoint_path: Not used by this function; kept so existing
            callers that pass it keep working.
        sadtalker_config_path: Not used by this function; kept so existing
            callers that pass it keep working.

    Returns:
        The assembled ``gr.Blocks`` interface.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print('device is', device)
    torch.set_default_device(device)

    # No TTS model is loaded here: the "Generate audio" button calls
    # predict(), which uses the module-level `tts` instance, so a second copy
    # would only cost load time and memory.
    kd_talker = Inferencer()

    # Page header, assembled line by line so the HTML text itself carries no
    # source-code indentation.
    header_html = (
        "\n"
        "<div align='center'>\n"
        "<h2> Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait </h2>\n"
        '<div style="display: flex; justify-content: center; align-items: center; gap: 20px;">\n'
        "<img src='https://newstatic.dukekunshan.edu.cn/mainsite/2021/08/07161629/large_dku-Logo-e1649298929570.png' alt='Logo' width='150'/>\n"
        "<img src='https://www.xjtlu.edu.cn/wp-content/uploads/2023/12/7c52fd62e9cf26cb493faa7f91c2782.png' width='250'/>\n"
        "</div>\n"
        "</div>\n"
    )

    with gr.Blocks(analytics_enabled=False) as interface:
        gr.Markdown(header_html)

        # Tracks which audio source ("upload" or "tts") was touched last; it
        # is passed to the generator together with both audio components.
        driven_audio_type = gr.State(value="upload")

        with gr.Row():
            # ---- Left column: inputs ----
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="sadtalker_source_image"):
                    with gr.TabItem("Upload image"):
                        source_image = gr.Image(label="Source image", sources="upload", type="filepath", scale=256)

                with gr.Tabs(elem_id="sadtalker_driven_audio"):
                    with gr.TabItem("Upload"):
                        upload_driven_audio = gr.Audio(label="Upload audio", sources="upload", type="filepath")
                        upload_driven_audio.change(lambda: "upload", outputs=driven_audio_type)
                        # Tracks which reference recording (uploaded or
                        # microphone) predict() should use as the voice.
                        reference_audio_type = gr.State(value="upload")

                    with gr.TabItem("TTS"):
                        upload_reference_audio = gr.Audio(label="Upload Reference Audio", sources="upload", type="filepath")
                        upload_reference_audio.change(lambda: "upload", outputs=reference_audio_type)
                        microphone_reference_audio = gr.Audio(label="Recorded Reference Audio", sources="microphone", type="filepath")
                        microphone_reference_audio.change(lambda: "microphone", outputs=reference_audio_type)
                        input_text = gr.Textbox(
                            label="Generating audio from text",
                            lines=5,
                            placeholder="please enter some text here, we generate the audio from text using @Coqui.ai TTS."
                        )
                        tts_button = gr.Button("Generate audio", elem_id="sadtalker_audio_generate", variant="primary")
                        tts_driven_audio = gr.Audio(label="Synthesised Audio", type="filepath")
                        tts_button.click(
                            fn=predict,
                            inputs=[input_text, upload_reference_audio, microphone_reference_audio, reference_audio_type],
                            outputs=[tts_driven_audio],
                        )
                        # Second listener on the same click: mark the
                        # synthesised audio as the active driving source.
                        tts_button.click(lambda: "tts", outputs=driven_audio_type)

            # ---- Right column: output and generation controls ----
            with gr.Column(variant="panel"):
                gen_video = gr.Video(label="Generated video", format="mp4", width=256)
                with gr.Tabs(elem_id="talker_checkbox"):
                    with gr.TabItem("KDTalker"):
                        smoothed_pitch = gr.Slider(minimum=0, maximum=1, step=0.1, label="Pitch", value=0.8)
                        smoothed_yaw = gr.Slider(minimum=0, maximum=1, step=0.1, label="Yaw", value=0.8)
                        smoothed_roll = gr.Slider(minimum=0, maximum=1, step=0.1, label="Roll", value=0.8)
                        smoothed_t = gr.Slider(minimum=0, maximum=1, step=0.1, label="T", value=0.8)
                        kd_submit = gr.Button("Generate", elem_id="kdtalker_generate", variant="primary")
                        kd_submit.click(
                            fn=kd_talker.generate_with_audio_img,
                            inputs=[
                                upload_driven_audio,
                                tts_driven_audio,
                                driven_audio_type,
                                source_image,
                                smoothed_pitch,
                                smoothed_yaw,
                                smoothed_roll,
                                smoothed_t,
                            ],
                            outputs=[gen_video]
                        )
    return interface
if __name__ == "__main__":
    # Script entry point: set the Gradio port through the environment, build
    # the interface, then serve it on all network interfaces.
    os.environ.update(GRADIO_SERVER_PORT="7860")
    demo = main()
    # HTTPS variant of the launch call, kept for reference:
    #demo.launch(server_name="0.0.0.0",ssl_certfile="cert.pem", ssl_keyfile="key.pem", ssl_verify=False, strict_cors = False)
    demo.launch(server_name="0.0.0.0")