File size: 3,431 Bytes
872c724
 
 
 
6a2e37d
872c724
6a2e37d
872c724
 
 
 
 
 
6a2e37d
872c724
 
 
 
6a2e37d
 
 
 
 
 
872c724
6a2e37d
872c724
 
6a2e37d
872c724
 
 
 
 
 
 
 
 
 
 
 
6a2e37d
872c724
6a2e37d
872c724
6a2e37d
872c724
6a2e37d
 
 
 
872c724
 
6a2e37d
 
872c724
6a2e37d
 
 
872c724
 
6a2e37d
872c724
6a2e37d
872c724
 
 
 
 
 
 
 
6a2e37d
 
872c724
6a2e37d
 
872c724
 
 
 
6a2e37d
 
872c724
6a2e37d
 
 
872c724
6a2e37d
872c724
6a2e37d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
from refacer import Refacer
import argparse
import ngrok
import time

# Argument parser
parser = argparse.ArgumentParser(description='Refacer')
parser.add_argument("--max_num_faces", type=int, help="Max number of faces on UI", default=5)
# store_true flags already default to False; no explicit default needed.
parser.add_argument("--force_cpu", help="Force CPU mode", action="store_true")
parser.add_argument("--share_gradio", help="Share Gradio", action="store_true")
parser.add_argument("--server_name", type=str, help="Server IP address", default="127.0.0.1")
parser.add_argument("--server_port", type=int, help="Server port", default=7860)
parser.add_argument("--colab_performance", help="Use in colab for better performance", action="store_true")
# The value is the ngrok authtoken itself; connect() also understands the
# "authtoken:username:password" form to enable HTTP basic auth on the tunnel.
parser.add_argument("--ngrok", type=str,
                    help="Use ngrok: pass the ngrok authtoken, or authtoken:username:password to also enable basic auth",
                    default=None)
parser.add_argument("--ngrok_region", type=str, help="ngrok region", default="us")
args = parser.parse_args()

# Initialize Refacer
print("Initializing Refacer...")
# perf_counter is monotonic, so the reported duration cannot be skewed by
# wall-clock adjustments the way time.time() can.
start_time = time.perf_counter()
refacer = Refacer(force_cpu=args.force_cpu, colab_performance=args.colab_performance)
# Number of face slots (tabs) built in the UI; run() slices its inputs by it.
num_faces = args.max_num_faces
print(f"Refacer initialized in {time.perf_counter() - start_time:.2f} seconds")

# Ngrok connection
def connect(token, port, options, host="127.0.0.1"):
    """Open an ngrok tunnel to the local Gradio server.

    Args:
        token: ngrok authtoken, or ``"authtoken:username:password"`` to also
            protect the tunnel with HTTP basic auth (everything after the
            first colon is used as the ``username:password`` credential).
        port: Local port the server listens on.
        options: Extra keyword arguments forwarded to ``ngrok.connect``
            (e.g. ``region``, ``authtoken_from_env``). The caller's dict is
            copied, never modified.
        host: Local address the tunnel forwards to.

    Returns:
        The public tunnel URL, or ``None`` if the tunnel was not opened.
        Failures are reported with ``print`` rather than raised, so the app
        keeps running without a tunnel.
    """
    options = dict(options)  # work on a copy; don't leak keys back to the caller
    account = None
    if token and ':' in token:
        # token = authtoken:username:password
        token, _, account = token.partition(':')
        if ':' not in account:
            # Malformed credentials: refuse rather than silently opening an
            # unauthenticated public tunnel (or crashing on the unpack).
            print("ngrok connection aborted: expected token format 'authtoken:username:password'")
            return None

    if not options.get('authtoken_from_env'):
        options['authtoken'] = token
    if account:
        options['basic_auth'] = account

    try:
        public_url = ngrok.connect(f"{host}:{port}", **options).url()
    except Exception as e:
        # Best-effort: a tunnel failure must not prevent the local UI from starting.
        print(f'ngrok connection aborted: {e}')
        return None
    print(f'ngrok connected to {host}:{port}! URL: {public_url}')
    return public_url

# Run reface
def run(*inputs):
    """Reface the uploaded video using every fully specified face slot.

    Gradio passes the inputs positionally, in the order wired up in
    ``button.click``: the video, then ``num_faces`` origin images, then
    ``num_faces`` destination images, then ``num_faces`` threshold values.
    A slot is used only when both its origin and destination images are set.

    Returns:
        Whatever ``refacer.reface(video_path, faces)`` returns; it is routed
        to the "Refaced video" output component.
    """
    video_path = inputs[0]
    origins = inputs[1:num_faces + 1]
    destinations = inputs[num_faces + 1:2 * num_faces + 1]
    thresholds = inputs[2 * num_faces + 1:]

    # gr.Image yields numpy arrays by default, whose truth value is ambiguous
    # (bool(array) raises), so test explicitly for "no image uploaded" (None).
    faces = [
        {'origin': origin_img, 'destination': dest_img, 'threshold': threshold}
        for origin_img, dest_img, threshold in zip(origins, destinations, thresholds)
        if origin_img is not None and dest_img is not None
    ]

    return refacer.reface(video_path, faces)

# UI setup: per-face component lists, filled in tab order below and fed to
# run() as positional inputs.
origin, destination, thresholds = [], [], []

with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Refacer")
    with gr.Row():
        video = gr.Video(label="Original video", format="mp4")
        video2 = gr.Video(label="Refaced video", interactive=False, format="mp4")

    # One tab per face slot: the face to replace, its replacement image and a
    # threshold slider for that pair.
    for face_no in range(1, num_faces + 1):
        with gr.Tab(f"Face #{face_no}"):
            with gr.Row():
                face_src = gr.Image(label="Face to replace")
                origin.append(face_src)
                face_dst = gr.Image(label="Destination face")
                destination.append(face_dst)
            with gr.Row():
                slider = gr.Slider(label="Threshold", minimum=0.0, maximum=1.0, value=0.2)
                thresholds.append(slider)

    with gr.Row():
        button = gr.Button("Reface", variant="primary")

    # Input order must match the slicing in run():
    # video, origins..., destinations..., thresholds...
    button.click(fn=run, inputs=[video, *origin, *destination, *thresholds], outputs=[video2])

# Optionally expose the local server through an ngrok tunnel (token taken
# from --ngrok, never from the environment).
if args.ngrok:
    ngrok_options = {'region': args.ngrok_region, 'authtoken_from_env': False}
    connect(args.ngrok, args.server_port, ngrok_options)

# Launch demo with request queueing enabled.
launch_kwargs = dict(
    show_error=True,
    share=args.share_gradio,
    server_name=args.server_name,
    server_port=args.server_port,
)
demo.queue().launch(**launch_kwargs)