fffiloni committed on
Commit
3b88a56
·
verified ·
1 Parent(s): 3cbfa4c

Update gradio_app.py

Browse files
Files changed (1) hide show
  1. gradio_app.py +38 -12
gradio_app.py CHANGED
@@ -3,6 +3,14 @@ import os
3
  os.environ['HYDRA_FULL_ERROR']='1'
4
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5
 
 
 
 
 
 
 
 
 
6
  import argparse
7
  import shutil
8
  import uuid
@@ -13,7 +21,7 @@ import cv2
13
  from rich.progress import track
14
  import tyro
15
 
16
-
17
  from PIL import Image
18
  import time
19
  import torch
@@ -133,7 +141,7 @@ class Inferencer(object):
133
 
134
  from model import get_model
135
  self.point_diffusion = get_model()
136
- ckpt = torch.load('KDTalker.pth')
137
 
138
  self.point_diffusion.load_state_dict(ckpt['model'])
139
  self.point_diffusion.eval()
@@ -368,16 +376,34 @@ class Inferencer(object):
368
  os.remove(path)
369
  os.remove(new_audio_path)
370
 
 
371
 
372
- if __name__ == '__main__':
373
- parser = argparse.ArgumentParser()
374
- parser.add_argument("-source_image", type=str, default="example/source_image/WDA_BenCardin1_000.png",
375
- help="source image")
376
- parser.add_argument("-driven_audio", type=str, default="example/driven_audio/WDA_BenCardin1_000.wav",
377
- help="driving audio")
378
- parser.add_argument("-output", type=str, default="results/output.mp4", help="output video file name", )
379
-
380
- args = parser.parse_args()
381
 
382
  Infer = Inferencer()
383
- Infer.generate_with_audio_img(args.source_image, args.driven_audio, args.output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  os.environ['HYDRA_FULL_ERROR']='1'
4
  os.environ['CUDA_VISIBLE_DEVICES'] = '0'
5
 
6
from huggingface_hub import snapshot_download

# Download the model weights from the Hugging Face Hub at import time.
# NOTE(review): local_dir="./" drops the checkpoint files into the current
# working directory — confirm this matches the torch.load('./KDTalker.pth')
# path used later in the file.
snapshot_download(
    repo_id = "ChaolongYang/KDTalker",
    local_dir = "./"
)
13
+
14
  import argparse
15
  import shutil
16
  import uuid
 
21
  from rich.progress import track
22
  import tyro
23
 
24
+ import gradio as gr
25
  from PIL import Image
26
  import time
27
  import torch
 
141
 
142
  from model import get_model
143
  self.point_diffusion = get_model()
144
+ ckpt = torch.load('./KDTalker.pth')
145
 
146
  self.point_diffusion.load_state_dict(ckpt['model'])
147
  self.point_diffusion.eval()
 
376
  os.remove(path)
377
  os.remove(new_audio_path)
378
 
379
def gradio_infer(source_image, driven_audio):
    """Gradio callback: animate `source_image` with `driven_audio`.

    Parameters
    ----------
    source_image : str
        Filepath of the source portrait image (gr.Image with type="filepath").
    driven_audio : str
        Path of the driving audio clip.

    Returns
    -------
    str
        Path of the rendered output video.
    """
    output_path = "results/output.mp4"
    # Make sure the output directory exists before inference tries to write
    # the video into it.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # NOTE(review): a fresh Inferencer (and thus all checkpoints) is built on
    # every request — consider hoisting it to module scope for speed.
    Infer = Inferencer()
    # BUGFIX: the original passed the undefined name `output`, which raised a
    # NameError on every submit; the intended variable is `output_path`.
    Infer.generate_with_audio_img(source_image, driven_audio, output_path)

    return output_path
387
+
388
# --- Gradio UI -------------------------------------------------------------
# Declarative layout: a title, an input column (image + audio + submit
# button) next to an output column (video), wired to gradio_infer.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# KDTalker")

        with gr.Row():

            with gr.Column():
                # Inputs: the source portrait (delivered to the callback as a
                # filepath) and the driving audio clip.
                # NOTE(review): gr.Audio has no explicit type= here; confirm
                # the callback accepts Gradio's default audio payload, or set
                # type="filepath" to match the image input.
                source_image = gr.Image(label="Source Image", type="filepath")
                driven_audio = gr.Audio(label="Driven Audio")
                submit_btn = gr.Button("Submit")

            with gr.Column():
                # Output: the rendered talking-head video.
                output_video = gr.Video(label="Output Video")

    # Wire the button to the inference callback defined above.
    submit_btn.click(
        fn = gradio_infer,
        inputs = [source_image, driven_audio],
        outputs = [output_video]
    )

demo.launch()
409
+