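"""Gradio demo: extract a portrait matte with StyleMatte, composite the
foreground onto a new background, and harmonize the result with PHNet."""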
import gradio as gr
from tools import Inference, Matting, log
from omegaconf import OmegaConf
import numpy as np
import torchvision.transforms.functional as tf
from PIL import Image

# Test-time settings (e.g. `image_size`) are read from the config file.
args = OmegaConf.load("./config/test.yaml")

# Module-level cache for the most recent composite and matte.
global_comp = None
global_mask = None

# Instantiate the harmonization and matting wrappers once at startup.
log("Model loading")
phnet = Inference(**args)
stylematte = Matting(**args)
log("Model loaded")


def harmonize(comp, mask):
    log("Inference started")
    if comp is None or mask is None:
        log("Empty source")
        # Placeholder output when either input is missing.
        return np.zeros((16, 16, 3), dtype=np.uint8)

    comp = comp.convert('RGB')
    mask = mask.convert('1')
    in_shape = comp.size[::-1]  # PIL size is (W, H); torchvision wants (H, W)

    # Run the network at its configured resolution, then restore the input size.
    comp = tf.resize(comp, [args.image_size, args.image_size])
    mask = tf.resize(mask, [args.image_size, args.image_size])

    compt = tf.to_tensor(comp)
    maskt = tf.to_tensor(mask)
    res = phnet.harmonize(compt, maskt)
    res = tf.resize(res, in_shape)

    log("Inference finished")

    # CHW float tensor in [0, 1] -> HWC uint8 array; clamp to avoid wrap-around.
    return np.uint8((res.clamp(0, 1) * 255)[0].permute(1, 2, 0).numpy())


def extract_matte(img, back):
    # Bind to the module-level cache; without `global`, the assignments
    # below would only create locals.
    global global_comp, global_mask

    if img is None or back is None:
        log("Empty source")
        return [None, None, None]

    mask, fg = stylematte.extract(img)
    fg_pil = Image.fromarray(np.uint8(fg))

    # Blend: foreground where the matte is 1, new background elsewhere.
    # mask is (H, W); PIL's resize expects (W, H), hence the reversed shape.
    composite = fg + (1 - mask[:, :, None]) * \
        np.array(back.resize(mask.shape[::-1]))
    composite_pil = Image.fromarray(np.uint8(composite))

    global_comp = composite_pil
    global_mask = mask

    return [composite_pil, mask, fg_pil]


def css(height=3, scale=2):
    return f".output_image {{height: {height}rem !important; width: {scale}rem !important;}}"


# Custom CSS must be passed to gr.Blocks; gr.Image components take no `css` argument.
with gr.Blocks(css=css(3, 3)) as demo:
    gr.Markdown(
        """
    # Welcome to the portrait transfer demo app!
    Select a source portrait image and a new background.
    """)
    btn_compose = gr.Button(value="Compose")

    with gr.Row():
        input_ui = gr.Image(
            type="numpy", label='Source image to extract the foreground from')
        back_ui = gr.Image(type="pil", label='The new background')

    gr.Examples(
        examples=[["./assets/comp.jpg", "./assets/back.jpg"]],
        inputs=[input_ui, back_ui],
    )

    gr.Markdown(
        """
    ## Resulting alpha matte and extracted foreground.
    """)
    with gr.Row():
        matte_ui = gr.Image(type="pil", label='Alpha matte')
        fg_ui = gr.Image(type="pil", image_mode='RGBA',
                         label='Extracted foreground')

    gr.Markdown(
        """
    ## Click the button and compare the composite with the harmonized version.
    """)
    btn_harmonize = gr.Button(value="Harmonize composite")

    with gr.Row():
        composite_ui = gr.Image(type="pil", label='Composite')
        harmonized_ui = gr.Image(type="pil", label='Harmonized composite')

    btn_compose.click(extract_matte,
                      inputs=[input_ui, back_ui],
                      outputs=[composite_ui, matte_ui, fg_ui])
    btn_harmonize.click(harmonize,
                        inputs=[composite_ui, matte_ui],
                        outputs=[harmonized_ui])


log("Interface created")
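# `share=True` serves the demo through a temporary public Gradio link.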
demo.launch(share=True)