# test_kdtalker / app.py
# YinuoGuo27's picture
# Update app.py
# cd1584f verified
# raw
# history blame
# 6.21 kB
import os, sys
import gradio as gr
#from difpoint.inference import Inferencer
from TTS.api import TTS
import torch
import time
from flask import send_from_directory
from huggingface_hub import snapshot_download
# --- Import-time setup -------------------------------------------------------
# Pull a snapshot of the KDTalker repository from the Hugging Face Hub into a
# local folder before anything else runs.
repo_id = "ChaolongYang/KDTalker"
local_dir = "./downloaded_repo"
snapshot_download(repo_id=repo_id, local_dir=local_dir)
print("Download complete! Repository saved at:", local_dir)

# Print the path of every file found under the download folder.
print("\nFiles downloaded:")
for dirpath, _subdirs, filenames in os.walk(local_dir):
    for filename in filenames:
        print(os.path.join(dirpath, filename))

# Use the GPU when one is visible to torch, otherwise fall back to the CPU,
# and make that the default device for newly created tensors.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
torch.set_default_device(device)

# Shared text-to-speech model; predict() reads this module-level instance.
tts = TTS('tts_models/multilingual/multi-dataset/your_tts').to(device)
# NOTE(review): the model was already moved on the line above, so this second
# call looks redundant -- kept to preserve the original behaviour.
tts.to(device)

# Folder that predict() writes the synthesised WAV file into.
result_dir = "results"
def predict(prompt, upload_reference_audio, microphone_reference_audio, reference_audio_type):
    """Synthesise ``prompt`` as speech, cloning the voice of a reference clip.

    Args:
        prompt: Text to turn into speech.
        upload_reference_audio: File path of the uploaded reference clip, or
            ``None`` when nothing was uploaded.
        microphone_reference_audio: File path of the recorded reference clip,
            or ``None`` when nothing was recorded.
        reference_audio_type: ``'upload'`` or ``'microphone'``; selects which
            of the two reference clips is used as the speaker sample.

    Returns:
        A ``gr.Audio`` pointing at the generated WAV file
        (``<result_dir>/output.wav``).

    Raises:
        ValueError: If ``reference_audio_type`` is not one of the two known
            values, or if the selected reference clip is missing.
    """
    reference_clips = {
        'upload': upload_reference_audio,
        'microphone': microphone_reference_audio,
    }

    # An unknown selector used to fall through both branches and crash with
    # UnboundLocalError; reject it explicitly instead.
    if reference_audio_type not in reference_clips:
        raise ValueError(
            f"Unknown reference audio type {reference_audio_type!r}; "
            "expected 'upload' or 'microphone'."
        )

    audio_file_pth = reference_clips[reference_audio_type]
    # Fail early with a clear message rather than handing None to the model.
    if not audio_file_pth:
        raise ValueError(
            f"No reference audio provided for source {reference_audio_type!r}; "
            "upload or record a reference clip first."
        )

    # The output folder is not created anywhere else, so make sure it exists
    # before the model tries to write into it.
    os.makedirs(result_dir, exist_ok=True)
    output_file_path = os.path.join(result_dir, 'output.wav')

    tts.tts_to_file(
        text=prompt,
        file_path=output_file_path,
        speaker_wav=audio_file_pth,
        language="en",
    )
    return gr.Audio(value=output_file_path, type='filepath')
def main(sadtaker_checkpoint_path=r"SadTalker/checkpoints", sadtalker_config_path=r"SadTalker/src/config"):
    """Build the KDTalker Gradio demo and return it without launching it.

    The page has two columns: the left one collects a source image and a
    driving audio clip (either uploaded directly or synthesised from text via
    ``predict``), the right one shows the generated video and the KDTalker
    controls.

    Args:
        sadtaker_checkpoint_path: Not used by this function; kept so existing
            callers that pass it keep working.
        sadtalker_config_path: Not used by this function; kept so existing
            callers that pass it keep working.

    Returns:
        The assembled ``gr.Blocks`` interface.
    """
    # Imported here because the module-level import of Inferencer is commented
    # out; without it the Inferencer() call below raises NameError.
    from difpoint.inference import Inferencer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print('device is', device)
    torch.set_default_device(device)

    # No TTS model is loaded here: predict() uses the module-level `tts`, so a
    # second local copy would only cost memory and start-up time.
    kd_talker = Inferencer()

    with gr.Blocks(analytics_enabled=False) as interface:
        gr.Markdown(
            """
<div align='center'>
<h2> Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait </h2>
<div style="display: flex; justify-content: center; align-items: center; gap: 20px;">
<img src='https://newstatic.dukekunshan.edu.cn/mainsite/2021/08/07161629/large_dku-Logo-e1649298929570.png' alt='Logo' width='150'/>
<img src='https://www.xjtlu.edu.cn/wp-content/uploads/2023/12/7c52fd62e9cf26cb493faa7f91c2782.png' width='250'/>
</div>
</div>
"""
        )

        # Remembers which driving-audio source was touched last
        # ("upload" or "tts"); passed to the generator on submit.
        driven_audio_type = gr.State(value="upload")

        with gr.Row():
            # ---- Left column: inputs --------------------------------------
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="sadtalker_source_image"):
                    with gr.TabItem("Upload image"):
                        source_image = gr.Image(label="Source image", sources="upload", type="filepath", scale=256)

                with gr.Tabs(elem_id="sadtalker_driven_audio"):
                    with gr.TabItem("Upload"):
                        upload_driven_audio = gr.Audio(label="Upload audio", sources="upload", type="filepath")
                        upload_driven_audio.change(lambda: "upload", outputs=driven_audio_type)
                        # Remembers which reference clip (uploaded or
                        # recorded) changed last; consumed by predict().
                        reference_audio_type = gr.State(value="upload")

                    with gr.TabItem("TTS"):
                        upload_reference_audio = gr.Audio(label="Upload Reference Audio", sources="upload", type="filepath")
                        upload_reference_audio.change(lambda: "upload", outputs=reference_audio_type)
                        microphone_reference_audio = gr.Audio(label="Recorded Reference Audio", sources="microphone", type="filepath")
                        microphone_reference_audio.change(lambda: "microphone", outputs=reference_audio_type)
                        input_text = gr.Textbox(
                            label="Generating audio from text",
                            lines=5,
                            placeholder="please enter some text here, we generate the audio from text using @Coqui.ai TTS."
                        )
                        tts_button = gr.Button("Generate audio", elem_id="sadtalker_audio_generate", variant="primary")
                        tts_driven_audio = gr.Audio(label="Synthesised Audio", type="filepath")
                        # Two listeners on the same click: one synthesises the
                        # audio, the other marks TTS as the active source.
                        tts_button.click(fn=predict, inputs=[input_text, upload_reference_audio, microphone_reference_audio, reference_audio_type], outputs=[tts_driven_audio])
                        tts_button.click(lambda: "tts", outputs=driven_audio_type)

            # ---- Right column: output and KDTalker controls ---------------
            with gr.Column(variant="panel"):
                gen_video = gr.Video(label="Generated video", format="mp4", width=256)
                with gr.Tabs(elem_id="talker_checkbox"):
                    with gr.TabItem("KDTalker"):
                        smoothed_pitch = gr.Slider(minimum=0, maximum=1, step=0.1, label="Pitch", value=0.8)
                        smoothed_yaw = gr.Slider(minimum=0, maximum=1, step=0.1, label="Yaw", value=0.8)
                        smoothed_roll = gr.Slider(minimum=0, maximum=1, step=0.1, label="Roll", value=0.8)
                        smoothed_t = gr.Slider(minimum=0, maximum=1, step=0.1, label="T", value=0.8)
                        kd_submit = gr.Button("Generate", elem_id="kdtalker_generate", variant="primary")
                        # Component values are handed to the generator
                        # positionally, in exactly this order.
                        kd_submit.click(
                            fn=kd_talker.generate_with_audio_img,
                            inputs=[
                                upload_driven_audio,
                                tts_driven_audio,
                                driven_audio_type,
                                source_image,
                                smoothed_pitch,
                                smoothed_yaw,
                                smoothed_roll,
                                smoothed_t,
                            ],
                            outputs=[gen_video]
                        )

    return interface
if __name__ == "__main__":
    # Set the GRADIO_SERVER_PORT environment variable before the app is built and launched.
    os.environ.update(GRADIO_SERVER_PORT="7860")
    demo = main()
    # TLS variant kept for reference:
    #demo.launch(server_name="0.0.0.0",ssl_certfile="cert.pem", ssl_keyfile="key.pem", ssl_verify=False, strict_cors = False)
    # Bind to all interfaces so the server is reachable from outside the host.
    demo.launch(server_name="0.0.0.0")