Update gradio_app.py

gradio_app.py (CHANGED, +38 -12)
@@ -3,6 +3,14 @@ import os
 os.environ['HYDRA_FULL_ERROR']='1'
 os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
+from huggingface_hub import snapshot_download
+
+# Download weights
+snapshot_download(
+    repo_id = "ChaolongYang/KDTalker",
+    local_dir = "./"
+)
+
 import argparse
 import shutil
 import uuid
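Note on the new download block: `snapshot_download` mirrors the whole repo into the working directory on every fresh start. If only the diffusion checkpoint is needed, the fetch can be narrowed; a minimal sketch, assuming `KDTalker.pth` sits at the repo root:

```python
from huggingface_hub import hf_hub_download, snapshot_download

# Fetch just the checkpoint (served from the local HF cache on reruns).
ckpt_path = hf_hub_download(repo_id="ChaolongYang/KDTalker",
                            filename="KDTalker.pth")

# Or keep snapshot_download but limit it to matching files.
snapshot_download(repo_id="ChaolongYang/KDTalker",
                  local_dir="./",
                  allow_patterns=["*.pth"])
```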
@@ -13,7 +21,7 @@ import cv2
 from rich.progress import track
 import tyro
 
-
+import gradio as gr
 from PIL import Image
 import time
 import torch
@@ -133,7 +141,7 @@ class Inferencer(object):
 
         from model import get_model
         self.point_diffusion = get_model()
-        ckpt = torch.load('KDTalker.pth')
+        ckpt = torch.load('./KDTalker.pth')
 
         self.point_diffusion.load_state_dict(ckpt['model'])
         self.point_diffusion.eval()
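Note on the checkpoint load: `torch.load` restores tensors onto the device they were saved from, which can fail before a GPU is attached (a common cold-start state on ZeroGPU Spaces). A defensive variant, with `point_diffusion` standing in for `self.point_diffusion`:

```python
import torch

# Deserialize onto CPU first; the model can be moved to CUDA later,
# once a GPU is actually attached to the process.
ckpt = torch.load('./KDTalker.pth', map_location='cpu')
point_diffusion.load_state_dict(ckpt['model'])
point_diffusion.eval()
```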
@@ -368,16 +376,34 @@ class Inferencer(object):
         os.remove(path)
         os.remove(new_audio_path)
 
+def gradio_infer(source_image, driven_audio):
 
-
-parser = argparse.ArgumentParser()
-parser.add_argument("-source_image", type=str, default="example/source_image/WDA_BenCardin1_000.png",
-                    help="source image")
-parser.add_argument("-driven_audio", type=str, default="example/driven_audio/WDA_BenCardin1_000.wav",
-                    help="driving audio")
-parser.add_argument("-output", type=str, default="results/output.mp4", help="output video file name", )
-
-args = parser.parse_args()
+    output_path = "results/output.mp4"
 
     Infer = Inferencer()
-Infer.generate_with_audio_img(
+    Infer.generate_with_audio_img(source_image, driven_audio, output_path)
+
+    return output_path
+
+with gr.Blocks() as demo:
+    with gr.Column():
+        gr.Markdown("# KDTalker")
+
+        with gr.Row():
+
+            with gr.Column():
+                source_image = gr.Image(label="Source Image", type="filepath")
+                driven_audio = gr.Audio(label="Driven Audio", type="filepath")
+                submit_btn = gr.Button("Submit")
+
+            with gr.Column():
+                output_video = gr.Video(label="Output Video")
+
+        submit_btn.click(
+            fn = gradio_infer,
+            inputs = [source_image, driven_audio],
+            outputs = [output_video]
+        )
+
+demo.launch()
+
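The defaults from the removed argparse block double as a smoke test for the new entry point; a sketch, assuming it runs from the repo root with the bundled example assets, before `demo.launch()` blocks:

```python
# Hypothetical smoke test reusing the old CLI default paths;
# gradio_infer returns the path it wrote the video to.
video_path = gradio_infer(
    "example/source_image/WDA_BenCardin1_000.png",
    "example/driven_audio/WDA_BenCardin1_000.wav",
)
print(video_path)  # -> results/output.mp4
```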
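Once the Space is running, the same handler can be called remotely. A sketch with `gradio_client`, assuming a recent client version, that the Space ID matches the weights repo, and that the click event keeps its auto-generated API name (derived from the function name):

```python
from gradio_client import Client, handle_file

client = Client("ChaolongYang/KDTalker")

# Inputs are passed in the order wired into submit_btn.click().
result = client.predict(
    handle_file("example/source_image/WDA_BenCardin1_000.png"),
    handle_file("example/driven_audio/WDA_BenCardin1_000.wav"),
    api_name="/gradio_infer",
)
print(result)  # local path to the downloaded output video
```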