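"""Gradio demo for KDTalker (audio-driven talking-portrait generation).

The app downloads the pretrained KDTalker weights from the Hugging Face Hub,
wraps Coqui TTS for optional text-to-speech with voice cloning, and exposes
both through a Gradio Blocks interface.
"""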
import os
from functools import partial

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from TTS.api import TTS

from difpoint.inference import Inferencer
# Fetch the pretrained KDTalker weights from the Hugging Face Hub.
repo_id = "ChaolongYang/KDTalker"
local_dir = "./downloaded_repo"
snapshot_download(repo_id=repo_id, local_dir=local_dir)
print("Download complete! Repository saved at:", local_dir)

print("\nFiles downloaded:")
for root, dirs, files in os.walk(local_dir):
    for file in files:
        print(os.path.join(root, file))
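# Note: with recent huggingface_hub versions, snapshot_download is incremental;
# files already present in local_dir are reused on later runs, not re-fetched.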
result_dir = "results"
os.makedirs(result_dir, exist_ok=True)


@spaces.GPU
def predict(prompt, upload_reference_audio, microphone_reference_audio, reference_audio_type, tts):
    """Synthesise speech for `prompt` with Coqui TTS, cloning the selected reference voice."""
    output_file_path = os.path.join(result_dir, "output.wav")
    if reference_audio_type == "microphone":
        audio_file_pth = microphone_reference_audio
    else:  # default to the uploaded reference audio
        audio_file_pth = upload_reference_audio
    tts.tts_to_file(
        text=prompt,
        file_path=output_file_path,
        speaker_wav=audio_file_pth,
        language="en",
    )
    return gr.Audio(value=output_file_path, type="filepath")
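# A minimal standalone sketch of the same TTS step (the "ref.wav" path is an
# assumption for illustration; @spaces.GPU is a no-op outside a ZeroGPU Space):
#
#     tts = TTS("tts_models/multilingual/multi-dataset/your_tts")
#     predict("Hello from KDTalker!", "ref.wav", None, "upload", tts)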
@spaces.GPU
def main():
    # Pick the computation device once and keep all models on it.
    if torch.cuda.is_available():
        device = "cuda"
        print("cuda available")
    else:
        device = "cpu"
    print("device is", device)
    torch.set_default_device(device)

    # Voice-cloning TTS model (Coqui YourTTS) and the KDTalker video generator.
    tts = TTS("tts_models/multilingual/multi-dataset/your_tts").to(device)
    kd_talker = Inferencer()
    with gr.Blocks(analytics_enabled=False) as interface:
        gr.Markdown(
            """
            <div align='center'>
              <h2> Unlock Pose Diversity: Accurate and Efficient Implicit Keypoint-based Spatiotemporal Diffusion for Audio-driven Talking Portrait </h2>
              <div style="display: flex; justify-content: center; align-items: center; gap: 20px;">
                <img src='https://newstatic.dukekunshan.edu.cn/mainsite/2021/08/07161629/large_dku-Logo-e1649298929570.png' alt='Logo' width='150'/>
                <img src='https://www.xjtlu.edu.cn/wp-content/uploads/2023/12/7c52fd62e9cf26cb493faa7f91c2782.png' width='250'/>
              </div>
            </div>
            """
        )
        # Hidden state tracking which audio source drives the video / the TTS reference.
        driven_audio_type = gr.State(value="upload")
        reference_audio_type = gr.State(value="upload")

        with gr.Row():
            with gr.Column(variant="panel"):
                with gr.Tabs(elem_id="kdtalker_source_image"):
                    with gr.TabItem("Upload image"):
                        source_image = gr.Image(label="Source image", sources="upload", type="filepath", scale=256)
                with gr.Tabs(elem_id="kdtalker_driven_audio"):
                    with gr.TabItem("Upload"):
                        upload_driven_audio = gr.Audio(label="Upload audio", sources="upload", type="filepath")
                        upload_driven_audio.change(lambda: "upload", outputs=driven_audio_type)
                    with gr.TabItem("TTS"):
                        upload_reference_audio = gr.Audio(label="Upload Reference Audio", sources="upload", type="filepath")
                        upload_reference_audio.change(lambda: "upload", outputs=reference_audio_type)
                        microphone_reference_audio = gr.Audio(label="Recorded Reference Audio", sources="microphone", type="filepath")
                        microphone_reference_audio.change(lambda: "microphone", outputs=reference_audio_type)
                        input_text = gr.Textbox(
                            label="Generating audio from text",
                            lines=5,
                            placeholder="Please enter some text here; we generate the audio from it using @Coqui.ai TTS.",
                        )
                        tts_button = gr.Button("Generate audio", elem_id="kdtalker_audio_generate", variant="primary")
                        tts_driven_audio = gr.Audio(label="Synthesised Audio", type="filepath")
                        # The TTS model is a plain Python object, not a Gradio component,
                        # so bind it with functools.partial rather than passing it via `inputs`.
                        tts_button.click(
                            fn=partial(predict, tts=tts),
                            inputs=[input_text, upload_reference_audio, microphone_reference_audio, reference_audio_type],
                            outputs=[tts_driven_audio],
                        )
                        tts_button.click(lambda: "tts", outputs=driven_audio_type)
            with gr.Column(variant="panel"):
                gen_video = gr.Video(label="Generated video", format="mp4", width=256)
                with gr.Tabs(elem_id="talker_checkbox"):
                    with gr.TabItem("KDTalker"):
                        # Smoothing strength for each head-pose channel fed to the inferencer.
                        smoothed_pitch = gr.Slider(minimum=0, maximum=1, step=0.1, label="Pitch", value=0.8)
                        smoothed_yaw = gr.Slider(minimum=0, maximum=1, step=0.1, label="Yaw", value=0.8)
                        smoothed_roll = gr.Slider(minimum=0, maximum=1, step=0.1, label="Roll", value=0.8)
                        smoothed_t = gr.Slider(minimum=0, maximum=1, step=0.1, label="T", value=0.8)
                        kd_submit = gr.Button("Generate", elem_id="kdtalker_generate", variant="primary")
                        kd_submit.click(
                            fn=kd_talker.generate_with_audio_img,
                            inputs=[
                                upload_driven_audio,
                                tts_driven_audio,
                                driven_audio_type,
                                source_image,
                                smoothed_pitch,
                                smoothed_yaw,
                                smoothed_roll,
                                smoothed_t,
                            ],
                            outputs=[gen_video],
                        )
    return interface
if __name__ == "__main__":
    os.environ["GRADIO_SERVER_PORT"] = "7860"
    demo = main()
    # demo.launch(server_name="0.0.0.0", ssl_certfile="cert.pem", ssl_keyfile="key.pem",
    #             ssl_verify=False, strict_cors=False)
    demo.launch(server_name="0.0.0.0")
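# To run locally: `python app.py`, then open http://localhost:7860 in a browser.
# On Hugging Face Spaces, the @spaces.GPU decorators request ZeroGPU workers
# for the TTS and video-generation calls.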